diff --git a/lisc/objects/base.py b/lisc/objects/base.py index dbf19e6a..8ba211af 100644 --- a/lisc/objects/base.py +++ b/lisc/objects/base.py @@ -36,9 +36,22 @@ def __init__(self): def __getitem__(self, key): - """Index into Base object, accessing Term.""" + """Index into Base object, accessing Term. - return self.get_term(self.get_index(key)) + Parameters + ---------- + key : str or int + Label or index of the element to extract. + """ + + return self.get_term(self.get_index(key) if isinstance(key, str) else key) + + + def __iter__(self): + """Allow for iterating across the object by stepping through terms.""" + + for ind in range(self.n_terms): + yield self.get_term(ind) @property diff --git a/lisc/objects/counts.py b/lisc/objects/counts.py index 6f446f69..e7155f01 100644 --- a/lisc/objects/counts.py +++ b/lisc/objects/counts.py @@ -50,6 +50,28 @@ def __init__(self): self.meta_data = None + def __getitem__(self, keys): + """Index into Counts object, accessing count. + + Parameters + ---------- + keys : list of (str, int) + Labels or indices for the data to access. + """ + + if not self.has_data: + raise IndexError('No data is available - cannot proceed.') + + if not isinstance(keys, list): + return ValueError('Input keys do not match the object.') + + ind0 = self.terms['A'].get_index(keys[0]) if isinstance(keys[0], str) else keys[0] + ind1 = self.terms['B' if self.terms['B'].terms else 'A'].get_index(keys[1]) \ + if isinstance(keys[1], str) else keys[1] + + return self.counts[ind0, ind1] + + @property def has_data(self): """Indicator for if the object has collected data.""" diff --git a/lisc/objects/words.py b/lisc/objects/words.py index 80ff7448..de8569f0 100644 --- a/lisc/objects/words.py +++ b/lisc/objects/words.py @@ -54,6 +54,13 @@ def __getitem__(self, label): return self.results[ind] + def __iter__(self): + """Allow for iterating across the object by stepping through collected results.""" + + for result in self.results: + yield result + + @property def has_data(self): """Indicator for if the object has collected data.""" diff --git a/lisc/tests/objects/test_base.py b/lisc/tests/objects/test_base.py index 46e70775..878512e2 100755 --- a/lisc/tests/objects/test_base.py +++ b/lisc/tests/objects/test_base.py @@ -20,6 +20,11 @@ def test_get_item(tbase_terms): assert isinstance(out, Term) assert out.label == 'label0' +def test_iter(tbase_terms): + + for term in tbase_terms: + assert isinstance(term, Term) + def test_get_index(tbase_terms): ind = tbase_terms.get_index('label0') diff --git a/lisc/tests/objects/test_counts.py b/lisc/tests/objects/test_counts.py index d6d940f8..d213e470 100755 --- a/lisc/tests/objects/test_counts.py +++ b/lisc/tests/objects/test_counts.py @@ -47,6 +47,14 @@ def compute_scores(counts): assert counts.score.any() assert counts.score_info['type'] == score_type +def check_dunders(counts): + + label0 = counts.terms['A'].labels[0] + label1 = counts.terms['B' if counts.terms['B'].terms else 'A'].labels[0] + + out = counts[label0, label1] + assert out == self.counts[0, 0] + def check_funcs(counts): counts.check_data() diff --git a/lisc/tests/objects/test_words.py b/lisc/tests/objects/test_words.py index 1c68452d..4c22524e 100755 --- a/lisc/tests/objects/test_words.py +++ b/lisc/tests/objects/test_words.py @@ -51,9 +51,17 @@ def test_collect(test_req): assert words.has_data assert len(words.results) == len(terms) + check_dunders(words) check_funcs(words) drop_data(words, retmax+1) +def check_dunders(words): + + for ind, result in enumerate(words): + ind += 1 + assert result + assert ind == len(words.results) + def check_funcs(words): words.check_data()