From 0395cfc26d5809ef31c2f5613016ffd60ed5281c Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 2 Sep 2022 09:08:40 +0200 Subject: [PATCH] Consolidate and freeze symbols (#11352) * Consolidate and freeze symbols Instead of having symbol values defined in three potentially conflicting places (`spacy.attrs`, `spacy.parts_of_speech`, `spacy.symbols`), define all symbols in `spacy.symbols` and reference those values in `spacy.attrs` and `spacy.parts_of_speech`. Remove deprecated and placeholder symbols from `spacy.attrs.IDS`. Make `spacy.attrs.NAMES` and `spacy.symbols.NAMES` reverse dicts rather than lists in order to support future use of hash values in `attr_id_t`. Minor changes: * Use `uint64_t` for attrs in `Doc.to_array` to support future use of hash values * Remove unneeded attrs filter for error message in `Doc.to_array` * Remove unused attr `SENT_END` * Handle dynamic size of attr_id_t in Doc.to_array * Undo added warnings * Refactor to make Doc.to_array more similar to Doc.from_array * Improve refactoring --- spacy/parts_of_speech.pxd | 2 +- spacy/strings.pyx | 57 +++++++++++++++++++++---------------- spacy/tests/test_symbols.py | 1 - 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/spacy/parts_of_speech.pxd b/spacy/parts_of_speech.pxd index 22a571be7b0..01f116ea688 100644 --- a/spacy/parts_of_speech.pxd +++ b/spacy/parts_of_speech.pxd @@ -8,7 +8,7 @@ cpdef enum univ_pos_t: ADV = symbols.ADV AUX = symbols.AUX CONJ = symbols.CONJ - CCONJ = symbols.CCONJ # U20 + CCONJ = symbols.CCONJ # U20 DET = symbols.DET INTJ = symbols.INTJ NOUN = symbols.NOUN diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 28e06a2ecea..a80985f6ff2 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -32,16 +32,34 @@ cdef class StringStore: for string in strings: self.add(string) - def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: - """Retrieve a string from a given hash. If a string - is passed as the input, add it to the store and return - its hash. + def __getitem__(self, object string_or_id): + """Retrieve a string from a given hash, or vice versa. - string_or_hash (int / str): The hash value to lookup or the string to store. - RETURNS (str / int): The stored string or the hash of the newly added string. + string_or_id (bytes, str or uint64): The value to encode. + Returns (str / uint64): The value to be retrieved. """ - if isinstance(string_or_hash, str): - return self.add(string_or_hash) + cdef hash_t str_hash + cdef Utf8Str* utf8str = NULL + + if isinstance(string_or_id, str): + if len(string_or_id) == 0: + return 0 + + # Return early if the string is found in the symbols LUT. + symbol = SYMBOLS_BY_STR.get(string_or_id, None) + if symbol is not None: + return symbol + else: + return hash_string(string_or_id) + elif isinstance(string_or_id, bytes): + return hash_utf8(string_or_id, len(string_or_id)) + elif _try_coerce_to_hash(string_or_id, &str_hash): + if str_hash == 0: + return "" + elif str_hash in SYMBOLS_BY_INT: + return SYMBOLS_BY_INT[str_hash] + else: + utf8str = self._map.get(str_hash) else: return self._get_interned_str(string_or_hash) @@ -111,24 +129,13 @@ cdef class StringStore: if isinstance(string_or_hash, str): return string_or_hash else: - return self._get_interned_str(string_or_hash) + # TODO: Raise an error instead + return self._map.get(string_or_id) is not NULL - def items(self) -> List[Tuple[str, int]]: - """Iterate over the stored strings and their hashes in insertion order. - - RETURNS: A list of string-hash pairs. - """ - # Even though we internally store the hashes as keys and the strings as - # values, we invert the order in the public API to keep it consistent with - # the implementation of the `__iter__` method (where we wish to iterate over - # the strings in the store). - cdef int i - pairs = [None] * self._keys.size() - for i in range(self._keys.size()): - str_hash = self._keys[i] - utf8str = self._map.get(str_hash) - pairs[i] = (self._decode_str_repr(utf8str), str_hash) - return pairs + if str_hash in SYMBOLS_BY_INT: + return True + else: + return self._map.get(str_hash) is not NULL def keys(self) -> List[str]: """Iterate over the stored strings in insertion order. diff --git a/spacy/tests/test_symbols.py b/spacy/tests/test_symbols.py index 2c2fcef755e..fb034accac2 100644 --- a/spacy/tests/test_symbols.py +++ b/spacy/tests/test_symbols.py @@ -1,5 +1,4 @@ import pytest - from spacy.symbols import IDS, NAMES V3_SYMBOLS = {