update document and segment word count #10449

Merged · 4 commits · Nov 8, 2024
38 changes: 35 additions & 3 deletions api/services/dataset_service.py
@@ -1414,9 +1414,13 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset):
                 created_by=current_user.id,
             )
             if document.doc_form == "qa_model":
+                segment_document.word_count += len(args["answer"])
                 segment_document.answer = args["answer"]

             db.session.add(segment_document)
+            # update document word count
+            document.word_count += segment_document.word_count
+            db.session.add(document)
             db.session.commit()

             # save vector index
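Note on the create path: for qa_model documents the new segment's word_count now includes the answer text, and the parent document's word_count is bumped by the segment's count in the same commit. A minimal sketch of that bookkeeping, using plain dataclasses as hypothetical stand-ins for the ORM models (Doc and Seg are not Dify's Document/DocumentSegment; counts are character-based, mirroring len(content) in the diff):

from dataclasses import dataclass, field

@dataclass
class Seg:  # stand-in for DocumentSegment
    content: str
    answer: str = ""
    word_count: int = 0

@dataclass
class Doc:  # stand-in for Document
    doc_form: str
    word_count: int = 0
    segments: list = field(default_factory=list)

def create_segment(doc: Doc, content: str, answer: str = "") -> Seg:
    seg = Seg(content=content, word_count=len(content))
    if doc.doc_form == "qa_model":
        seg.word_count += len(answer)  # answer characters count toward the segment
        seg.answer = answer
    doc.segments.append(seg)
    doc.word_count += seg.word_count   # keep the document total in sync
    return seg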
@@ -1435,6 +1439,7 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset):
     @classmethod
     def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
         lock_name = "multi_add_segment_lock_document_id_{}".format(document.id)
+        increment_word_count = 0
         with redis_client.lock(lock_name, timeout=600):
             embedding_model = None
             if dataset.indexing_technique == "high_quality":
@@ -1460,7 +1465,10 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
                 tokens = 0
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    if document.doc_form == "qa_model":
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                    else:
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1478,6 +1486,8 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
                 )
                 if document.doc_form == "qa_model":
                     segment_document.answer = segment_item["answer"]
+                    segment_document.word_count += len(segment_item["answer"])
+                increment_word_count += segment_document.word_count
                 db.session.add(segment_document)
                 segment_data_list.append(segment_document)

@@ -1486,7 +1496,9 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
                     keywords_list.append(segment_item["keywords"])
                 else:
                     keywords_list.append(None)
-
+            # update document word count
+            document.word_count += increment_word_count
+            db.session.add(document)
             try:
                 # save vector index
                 VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
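The batch path accumulates each segment's count into increment_word_count and writes the document total once after the loop, rather than updating the document row per segment. Reusing the Seg/Doc stand-ins from the sketch above (multi_create_segments is a hypothetical condensation, not the service method itself):

def multi_create_segments(doc: Doc, items: list) -> int:
    """Create many segments and return the amount added to doc.word_count."""
    increment = 0
    for item in items:
        seg = Seg(content=item["content"], word_count=len(item["content"]))
        if doc.doc_form == "qa_model":
            seg.answer = item["answer"]
            seg.word_count += len(item["answer"])
        increment += seg.word_count    # counted for every segment, qa or not
        doc.segments.append(seg)
    doc.word_count += increment        # one document update instead of N
    return increment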
@@ -1527,17 +1539,25 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
             else:
                 raise ValueError("Can't update disabled segment")
         try:
+            word_count_change = segment.word_count
             content = segment_update_entity.content
             if segment.content == content:
+                segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
                 if segment_update_entity.keywords:
                     segment.keywords = segment_update_entity.keywords
                 segment.enabled = True
                 segment.disabled_at = None
                 segment.disabled_by = None
                 db.session.add(segment)
                 db.session.commit()
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 # update segment index task
                 if segment_update_entity.enabled:
                     VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
@@ -1554,7 +1574,10 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
                     )

                     # calc embedding use tokens
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    if document.doc_form == "qa_model":
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                    else:
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment.content = content
                 segment.index_node_hash = segment_hash
                 segment.word_count = len(content)
@@ -1569,6 +1592,12 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 db.session.add(segment)
                 db.session.commit()
                 # update segment vector index
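The update path snapshots the old segment count, recomputes it from the new content (plus answer for qa_model), and applies the signed difference to the document, clamped at zero so a drifted total can never go negative. A sketch of that delta logic, again on the stand-in classes rather than the real service:

def update_segment(doc: Doc, seg: Seg, content: str, answer: str = "") -> None:
    old = seg.word_count
    seg.content = content
    seg.word_count = len(content)
    if doc.doc_form == "qa_model":
        seg.answer = answer
        seg.word_count += len(answer)
    delta = seg.word_count - old       # signed: shrinking a segment subtracts
    if delta != 0:
        doc.word_count = max(0, doc.word_count + delta)  # clamp, as in the diff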
@@ -1597,6 +1626,9 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
             redis_client.setex(indexing_cache_key, 600, 1)
             delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
+        # update document word count
+        document.word_count -= segment.word_count
+        db.session.add(document)
         db.session.commit()

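Delete closes the loop by subtracting the removed segment's count. With all four paths in place, the intended invariant is document.word_count == sum of its segments' word_count; a quick self-check built on the hypothetical sketches above:

def delete_segment(doc: Doc, seg: Seg) -> None:
    doc.segments.remove(seg)
    doc.word_count -= seg.word_count

# round-trip check of the invariant
doc = Doc(doc_form="qa_model")
s1 = create_segment(doc, "What is Dify?", "An LLM app platform.")
s2 = create_segment(doc, "Second question", "Second answer")
update_segment(doc, s1, "What is Dify, exactly?", "An open-source LLM app platform.")
delete_segment(doc, s2)
assert doc.word_count == sum(s.word_count for s in doc.segments)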
7 changes: 6 additions & 1 deletion api/tasks/batch_create_segment_to_index_task.py
@@ -57,7 +57,7 @@ def batch_create_segment_to_index_task(
                 model_type=ModelType.TEXT_EMBEDDING,
                 model=dataset.embedding_model,
             )
-
+        word_count_change = 0
         for segment in content:
             content = segment["content"]
             doc_id = str(uuid.uuid4())
@@ -86,8 +86,13 @@ def batch_create_segment_to_index_task(
             )
             if dataset_document.doc_form == "qa_model":
                 segment_document.answer = segment["answer"]
+                segment_document.word_count += len(segment["answer"])
+            word_count_change += segment_document.word_count
             db.session.add(segment_document)
             document_segments.append(segment_document)
+        # update document word count
+        dataset_document.word_count += word_count_change
+        db.session.add(dataset_document)
         # add index to db
         indexing_runner = IndexingRunner()
         indexing_runner.batch_add_segments(document_segments, dataset)
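The async batch task mirrors the accumulate-then-apply pattern above. Separately, both creation paths now estimate embedding tokens over question plus answer for qa_model documents; a hedged sketch of that branch, where count_tokens is a stand-in parameter for embedding_model.get_text_embedding_num_tokens rather than the real API:

def estimate_embedding_tokens(doc_form: str, content: str, answer: str, count_tokens) -> int:
    if doc_form == "qa_model":
        # the indexed text is question plus answer, so bill for both
        return count_tokens([content + answer])
    return count_tokens([content])

# usage with a trivial whitespace tokenizer as the stand-in counter
tokens = estimate_embedding_tokens(
    "qa_model", "Q: what is this?", " A: a word-count fix.",
    lambda texts: sum(len(t.split()) for t in texts),
)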