Updated DocumentTermMatrix implementation #205

Merged (2 commits, Sep 10, 2024)
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 ### Fixed
 
+## [0.54.0]
+
+### Updated
+- Improved the performance of BM25/Keywords-based indices for >10M documents. Introduced new kwargs of `min_term_freq` and `max_terms` in `RT.get_keywords` to reduce the size of the vocabulary. See `?RT.get_keywords` for more information.
+
 ## [0.53.0]
 
 ### Added
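For context on this changelog entry, a minimal sketch of the new kwargs in action. Assumptions: `RAGTools` is accessed via `PromptingTools.Experimental.RAGTools`, and `Snowball`, `LinearAlgebra`, and `SparseArrays` are loaded to activate the required extensions (the example strings are illustrative only):

```julia
using PromptingTools
using Snowball, LinearAlgebra, SparseArrays  # activate both package extensions
const RT = PromptingTools.Experimental.RAGTools

docs = ["apple banana cherry apple", "banana date fig grape", "apple banana cherry date"]
# Drop terms seen fewer than 2 times and cap the vocabulary at the 100 most frequent terms
dtm = RT.get_keywords(RT.KeywordsProcessor(), docs; min_term_freq = 2, max_terms = 100)
```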
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "PromptingTools"
 uuid = "670122d1-24a8-4d70-bfce-740807c42192"
 authors = ["J S @svilupp and contributors"]
-version = "0.53.0"
+version = "0.54.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
48 changes: 38 additions & 10 deletions ext/RAGToolsExperimentalExt.jl
@@ -110,41 +110,69 @@ function Base.hcat(d1::RT.DocumentTermMatrix{<:AbstractSparseMatrix},
 end
 
 """
-    document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
+    RT.document_term_matrix(
+        documents::AbstractVector{<:AbstractVector{T}};
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString}
 
 Builds a sparse matrix of term frequencies and document lengths from the given vector of documents wrapped in type `DocumentTermMatrix`.
 
 Expects a vector of preprocessed (tokenized) documents, where each document is a vector of strings (clean tokens).
 
 Returns: `DocumentTermMatrix`
 
+# Arguments
+- `documents`: A vector of documents, where each document is a vector of terms (clean tokens).
+- `min_term_freq`: The minimum frequency a term must have to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice will be included.
+- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included.
+
 # Example
 ```
 documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
 dtm = document_term_matrix(documents)
 ```
 """
-function RT.document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
-    T = eltype(documents) |> eltype
-    vocab = convert(Vector{T}, unique(vcat(documents...)))
-    vocab_lookup = Dict{T, Int}(t => i for (i, t) in enumerate(vocab))
+function RT.document_term_matrix(
+        documents::AbstractVector{<:AbstractVector{T}};
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString}
+    ## Calculate term frequencies, sort descending
+    counts = Dict{T, Int}()
+    @inbounds for doc in documents
+        for term in doc
+            counts[term] = get(counts, term, 0) + 1
+        end
+    end
+    counts = sort(collect(counts), by = x -> -x[2]) |> Base.Fix2(first, max_terms) |>
+             Base.Fix1(filter!, x -> x[2] >= min_term_freq)
+    ## Create vocabulary
+    vocab = convert(Vector{T}, getindex.(counts, 1))
+    vocab_lookup = Dict{T, Int}(term => i for (i, term) in enumerate(vocab))
     N = length(documents)
     doc_freq = zeros(Int, length(vocab))
-    term_freq = spzeros(Float32, N, length(vocab))
     doc_lengths = zeros(Float32, N)
+    ## Term frequency matrix to be recorded via its sparse entries: I, J, V
+    # term_freq = spzeros(Float32, N, length(vocab))
+    I, J, V = Int[], Int[], Float32[]
 
+    unique_terms = Set{eltype(vocab)}()
+    sizehint!(unique_terms, 1000)
     for di in eachindex(documents)
-        unique_terms = Set{eltype(vocab)}()
+        empty!(unique_terms)
         doc = documents[di]
-        for t in doc
+        @inbounds for t in doc
             doc_lengths[di] += 1
-            tid = vocab_lookup[t]
-            term_freq[di, tid] += 1
+            tid = get(vocab_lookup, t, nothing)
+            tid === nothing && continue
+            push!(I, di)
+            push!(J, tid)
+            push!(V, 1.0f0)
             if !(t in unique_terms)
                 doc_freq[tid] += 1
                 push!(unique_terms, t)
             end
         end
     end
+    ## combine repeated terms with `+`
+    term_freq = sparse(I, J, V, N, length(vocab), +)
     idf = @. log(1.0f0 + (N - doc_freq + 0.5f0) / (doc_freq + 0.5f0))
     sumdl = sum(doc_lengths)
    doc_rel_length = sumdl == 0 ? zeros(Float32, N) : doc_lengths ./ (sumdl / N)
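The core of this change swaps incremental writes into a `spzeros` matrix, where each assignment can reshuffle the compressed column storage, for accumulating coordinate triplets `(I, J, V)` and building the matrix once. The trailing `+` argument to `sparse` tells it how to combine duplicate `(row, col)` entries, which is exactly how repeated terms in a document become counts. A standalone sketch of the pattern, with illustrative values:

```julia
using SparseArrays

# Two "documents" (rows), five "terms" (columns); term 3 occurs twice in document 1
I = [1, 1, 2]         # row index per term occurrence
J = [3, 3, 5]         # column index per term occurrence
V = Float32[1, 1, 1]  # one count per occurrence
tf = sparse(I, J, V, 2, 5, +)  # duplicate entries at (1, 3) are combined with `+`
@assert tf[1, 3] == 2.0f0 && tf[2, 5] == 1.0f0
```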
15 changes: 8 additions & 7 deletions ext/SnowballPromptingToolsExt.jl
@@ -12,12 +12,14 @@ using Snowball
 RT._stem(stemmer::Snowball.Stemmer, text::AbstractString) = Snowball.stem(stemmer, text)
 
 """
-    get_keywords(processor::KeywordsProcessor, docs::AbstractVector{<:AbstractString};
+    RT.get_keywords(
+        processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
         verbose::Bool = true,
         stemmer = nothing,
-        stopwords::Set{String} = Set(STOPWORDS),
+        stopwords::Set{String} = Set(RT.STOPWORDS),
         return_keywords::Bool = false,
         min_length::Integer = 3,
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int),
         kwargs...)
 
 Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`.
@@ -29,6 +31,8 @@ Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`.
 - `stopwords`: A set of stopwords to remove. Default is `Set(STOPWORDS)`.
 - `return_keywords`: A boolean flag for returning the keywords. Default is `false`. Useful for query processing in search time.
 - `min_length`: The minimum length of the keywords. Default is `3`.
+- `min_term_freq`: The minimum frequency a term must have to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice will be included.
+- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included.
 """
 function RT.get_keywords(
         processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
@@ -37,16 +41,13 @@ function RT.get_keywords(
         stopwords::Set{String} = Set(RT.STOPWORDS),
         return_keywords::Bool = false,
         min_length::Integer = 3,
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int),
         kwargs...)
     ## check if extension is available
     ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt)
     if isnothing(ext)
         error("You need to also import LinearAlgebra and SparseArrays to use this function")
     end
-    ## ext = Base.get_extension(PromptingTools, :SnowballPromptingToolsExt)
-    ## if isnothing(ext)
-    ##     error("You need to also import Snowball.jl to use this function")
-    ## end
     ## Preprocess text into tokens
     stemmer = !isnothing(stemmer) ? stemmer : Snowball.Stemmer("english")
     # Single-threaded as stemmer is not thread-safe
@@ -56,7 +57,7 @@ function RT.get_keywords(
     return_keywords && return keywords
 
     ## Create DTM
-    dtm = RT.document_term_matrix(keywords)
+    dtm = RT.document_term_matrix(keywords; min_term_freq, max_terms)
 
     verbose && @info "Done processing DocumentTermMatrix."
     return dtm
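Worth noting from the docstring: `return_keywords = true` short-circuits before the `DocumentTermMatrix` is built, which is useful at query time when only the stemmed, stopword-filtered tokens are needed. A hedged sketch, reusing the `RT` setup from the changelog example above:

```julia
# Query time: get the processed tokens for a query without building a DTM
query_tokens = RT.get_keywords(RT.KeywordsProcessor(), ["another test document"];
    return_keywords = true)
```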
2 changes: 1 addition & 1 deletion src/Experimental/RAGTools/preparation.jl
@@ -207,7 +207,7 @@ function get_chunks(chunker::AbstractChunker,
         # split into chunks by recursively trying the separators provided
         # if you want to start simple - just do `split(text,"\n\n")`
         doc_chunks = PT.recursive_splitter(doc_raw, separators; max_length) .|> strip |>
-                     x -> filter(!isempty, x)
+                     Base.Fix1(filter!, !isempty)
         # skip if no chunks found
         isempty(doc_chunks) && continue
         append!(output_chunks, doc_chunks)
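A note on this refactor: `Base.Fix1(f, x)` returns a callable struct behaving like `y -> f(x, y)`, so `Base.Fix1(filter!, !isempty)` is equivalent to the old anonymous function but reuses a precompiled type instead of defining a fresh closure, and `filter!` filters in place rather than allocating a copy. A quick illustration:

```julia
chunks = ["a", "", "b", ""]
keep_nonempty = Base.Fix1(filter!, !isempty)  # behaves like x -> filter!(!isempty, x)
@assert keep_nonempty(chunks) == ["a", "b"]   # mutates and returns `chunks`
```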
26 changes: 26 additions & 0 deletions test/Experimental/RAGTools/preparation.jl
@@ -115,6 +115,32 @@ end
     @test Set(dtm.vocab) == Set(["this", "test", "document", "anoth", "more", "text"])
     @test size(dtm.tf) == (2, 6)
 
+    # Test for KeywordsProcessor with min_term_freq and max_terms
+    docs_freq = [
+        "apple banana cherry apple",
+        "banana date fig grape",
+        "apple banana cherry date",
+        "elephant fig grape"
+    ]
+    processor_freq = KeywordsProcessor()
+
+    # Test with min_term_freq = 2
+    dtm_freq = get_keywords(processor_freq, docs_freq; min_term_freq = 2)
+    @test Set(dtm_freq.vocab) ==
+          Set(["appl", "banana", "cherri", "date", "fig", "grape"])
+    @test size(dtm_freq.tf) == (4, 6)
+
+    # Test with max_terms = 3
+    dtm_max = get_keywords(processor_freq, docs_freq; max_terms = 3)
+    @test length(dtm_max.vocab) == 3
+    @test size(dtm_max.tf) == (4, 3)
+
+    # Test with both min_term_freq = 2 and max_terms = 2
+    dtm_both = get_keywords(processor_freq, docs_freq; min_term_freq = 2, max_terms = 2)
+    @test length(dtm_both.vocab) == 2
+    @test size(dtm_both.tf) == (4, 2)
+    @test all(sum(dtm_both.tf, dims = 1) .>= 2)
+
     # Test for KeywordsProcessor with custom stemmer and stopwords
     custom_stemmer = Snowball.Stemmer("french")
     dtm_custom = get_keywords(