From cf79aa14850917a5ce72f7107fb87dfc78a31b48 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 21 Dec 2023 13:21:17 +0000 Subject: [PATCH] feat: add support for single meta dict in `TextFileToDocument` (#6606) * add support for single meta dict * reno * reno * mypy * extract to function * docstring * mypy --- haystack/components/converters/txt.py | 22 ++++++++------ haystack/components/converters/utils.py | 24 ++++++++++++++- ...tadata-txt-converter-a02bf90c60262701.yaml | 4 +++ test/components/converters/test_utils.py | 29 +++++++++++++++++++ 4 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 releasenotes/notes/single-metadata-txt-converter-a02bf90c60262701.yaml create mode 100644 test/components/converters/test_utils.py diff --git a/haystack/components/converters/txt.py b/haystack/components/converters/txt.py index 08c48b97c0..1af25cfe1e 100644 --- a/haystack/components/converters/txt.py +++ b/haystack/components/converters/txt.py @@ -4,7 +4,7 @@ from haystack import Document, component from haystack.dataclasses import ByteStream -from haystack.components.converters.utils import get_bytestream_from_source +from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata logger = logging.getLogger(__name__) @@ -38,25 +38,29 @@ def __init__(self, encoding: str = "utf-8"): self.encoding = encoding @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None): + def run( + self, + sources: List[Union[str, Path, ByteStream]], + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + ): """ Convert text files to Documents. :param sources: A list of paths to text files or ByteStream objects. Note that if an encoding is specified in the metadata of a ByteStream, it will override the component's default. - :param meta: Optional list of metadata to attach to the Documents. 
- The length of the list must match the number of sources. Defaults to `None`. + :param meta: Optional metadata to attach to the Documents. + This value can be either a list of dictionaries or a single dictionary. + If it's a single dictionary, its content is added to the metadata of all produced Documents. + If it's a list, the length of the list must match the number of sources, because the two lists will be zipped. + Defaults to `None`. :return: A dictionary containing a list of Document objects under the 'documents' key. """ documents = [] - if meta is None: - meta = [{}] * len(sources) - elif len(sources) != len(meta): - raise ValueError("The length of the metadata list must match the number of sources.") + meta_list = normalize_metadata(meta, sources_count=len(sources)) - for source, metadata in zip(sources, meta): + for source, metadata in zip(sources, meta_list): try: bytestream = get_bytestream_from_source(source) except Exception as e: diff --git a/haystack/components/converters/utils.py b/haystack/components/converters/utils.py index d5040635e2..908a0a0bd5 100644 --- a/haystack/components/converters/utils.py +++ b/haystack/components/converters/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import List, Union, Dict, Any, Optional from haystack.dataclasses import ByteStream @@ -18,3 +18,25 @@ def get_bytestream_from_source(source: Union[str, Path, ByteStream]) -> ByteStre bs.meta["file_path"] = str(source) return bs raise ValueError(f"Unsupported source type {type(source)}") + + +def normalize_metadata( + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], sources_count: int +) -> List[Dict[str, Any]]: + """ + Given all the possible values of the meta input for a converter (None, dictionary or list of dicts), + makes sure to return a list of dictionaries of the correct length for the converter to use. 
+ + :param meta: the meta input of the converter, as-is + :param sources_count: the number of sources the converter received + :returns: a list of dictionaries of the same length as the sources list + """ + if meta is None: + return [{}] * sources_count + if isinstance(meta, dict): + return [meta] * sources_count + if isinstance(meta, list): + if sources_count != len(meta): + raise ValueError("The length of the metadata list must match the number of sources.") + return meta + raise ValueError("meta must be either None, a dictionary or a list of dictionaries.") diff --git a/releasenotes/notes/single-metadata-txt-converter-a02bf90c60262701.yaml b/releasenotes/notes/single-metadata-txt-converter-a02bf90c60262701.yaml new file mode 100644 index 0000000000..68c90c59cb --- /dev/null +++ b/releasenotes/notes/single-metadata-txt-converter-a02bf90c60262701.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adds support for single metadata dictionary input in `TextFileToDocument`. diff --git a/test/components/converters/test_utils.py b/test/components/converters/test_utils.py new file mode 100644 index 0000000000..68133d193f --- /dev/null +++ b/test/components/converters/test_utils.py @@ -0,0 +1,29 @@ +import pytest +from haystack.components.converters.utils import normalize_metadata + + +def test_normalize_metadata_None(): + assert normalize_metadata(None, sources_count=1) == [{}] + assert normalize_metadata(None, sources_count=3) == [{}, {}, {}] + + +def test_normalize_metadata_single_dict(): + assert normalize_metadata({"a": 1}, sources_count=1) == [{"a": 1}] + assert normalize_metadata({"a": 1}, sources_count=3) == [{"a": 1}, {"a": 1}, {"a": 1}] + + +def test_normalize_metadata_list_of_right_size(): + assert normalize_metadata([{"a": 1}], sources_count=1) == [{"a": 1}] + assert normalize_metadata([{"a": 1}, {"b": 2}, {"c": 3}], sources_count=3) == [{"a": 1}, {"b": 2}, {"c": 3}] + + +def test_normalize_metadata_list_of_wrong_size(): + with pytest.raises(ValueError, match="The length 
of the metadata list must match the number of sources."): + normalize_metadata([{"a": 1}], sources_count=3) + with pytest.raises(ValueError, match="The length of the metadata list must match the number of sources."): + assert normalize_metadata([{"a": 1}, {"b": 2}, {"c": 3}], sources_count=1) + + +def test_normalize_metadata_other_type(): + with pytest.raises(ValueError, match="meta must be either None, a dictionary or a list of dictionaries."): + normalize_metadata(({"a": 1},), sources_count=1)