Skip to content

Commit

Permalink
Handling for iterables without 0-th element, fixes #2556 (#2629)
Browse files Browse the repository at this point in the history
* Handling for iterables without 0-th element, fixes #2556

* Improved accessing the first element for the case of big datasets
  • Loading branch information
Hiyorimi authored and mpenkov committed Oct 10, 2019
1 parent a7713aa commit 289a6ca
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gensim/sklearn_api/d2vmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def fit(self, X, y=None):
The trained model.
"""
if isinstance(X[0], doc2vec.TaggedDocument):
if isinstance([i for i in X[:1]][0], doc2vec.TaggedDocument):
d2v_sentences = X
else:
d2v_sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(X)]
Expand Down
57 changes: 57 additions & 0 deletions gensim/test/test_d2vmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking D2VTransformer class.
"""

import unittest
import logging
from gensim.sklearn_api import D2VTransformer
from gensim.test.utils import common_texts


class IteratorForIterable:
    """Iterator capable of folding into list.

    Walks a sized, integer-indexable container from position 0 upwards,
    handing out one element per ``next()`` call.
    """

    def __init__(self, iterable):
        """Remember the container to walk and start at its first position.

        Parameters
        ----------
        iterable
            Container supporting ``len()`` and integer indexing from 0.

        """
        self._data = iterable
        self._index = 0

    def __iter__(self):
        """Return the iterator itself.

        The iterator protocol requires this; without it ``list(it)`` and
        ``for x in it`` raise ``TypeError`` even though ``__next__`` exists.
        """
        return self

    def __next__(self):
        """Return the next element of the container.

        Raises
        ------
        StopIteration
            Once every position of the container has been visited.

        """
        if self._index >= len(self._data):
            raise StopIteration
        result = self._data[self._index]
        self._index += 1
        return result


class IterableWithoutZeroElement:
    """
    Iterable wrapper that refuses lookups by key 0.

    Emulates pandas.Series behaviour without 0-th element
    (as after calling `series.index += 1`). Any other key is forwarded
    unchanged to the wrapped data, and iteration still visits every element.
    """
    def __init__(self, data):
        # Wrapped container; must be sized and indexable for iteration to work.
        self.data = data

    def __iter__(self):
        # Iteration goes through the helper iterator, not through __getitem__,
        # so the "missing" key 0 does not affect it.
        return IteratorForIterable(self.data)

    def __getitem__(self, key):
        zero_requested = key == 0
        if zero_requested:
            raise KeyError("Emulation of absence of item with key 0.")
        return self.data[key]


class TestD2VTransformer(unittest.TestCase):
    """Regression tests for the input handling of D2VTransformer.fit."""

    def test_works_with_iterable_not_having_element_with_zero_index(self):
        """fit() must accept an iterable for which ``X[0]`` raises KeyError.

        The method name starts with lowercase ``test`` so that unittest's
        default loader collects it; a ``Test...`` name is silently skipped.
        The test passes when fit() completes without raising.
        """
        a = IterableWithoutZeroElement(common_texts)
        transformer = D2VTransformer(min_count=1, size=5)
        transformer.fit(a)


if __name__ == '__main__':
    # Running the file directly: log everything with timestamps, then hand
    # control to the unittest runner.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s',
    )
    unittest.main()

0 comments on commit 289a6ca

Please sign in to comment.