Skip to content

Commit

Permalink
Sklearn LDA wrapper now works in sklearn pipeline (#1213)
Browse files Browse the repository at this point in the history
  • Loading branch information
kris-singh authored and tmylk committed Mar 21, 2017
1 parent cc86005 commit 97cd64f
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 60 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ install:
- pip install annoy
- pip install testfixtures
- pip install unittest2
- pip install scikit-learn
- pip install Morfessor==2.0.2a4
- python setup.py install
script: python setup.py test
197 changes: 147 additions & 50 deletions docs/notebooks/sklearn_wrapper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel"
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklearnWrapperLdaModel"
]
},
{
Expand All @@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 21,
"metadata": {
"collapsed": true
},
Expand Down Expand Up @@ -85,7 +85,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 22,
"metadata": {
"collapsed": false
},
Expand All @@ -100,21 +100,27 @@
{
"data": {
"text/plain": [
"[(0,\n",
" u'0.164*\"computer\" + 0.117*\"system\" + 0.105*\"graph\" + 0.061*\"server\" + 0.057*\"tree\" + 0.046*\"malfunction\" + 0.045*\"kernel\" + 0.045*\"complier\" + 0.043*\"loading\" + 0.039*\"hamiltonian\"'),\n",
" (1,\n",
" u'0.102*\"graph\" + 0.083*\"system\" + 0.072*\"tree\" + 0.064*\"server\" + 0.059*\"user\" + 0.059*\"computer\" + 0.057*\"trees\" + 0.056*\"eulerian\" + 0.055*\"node\" + 0.052*\"flow\"')]"
"array([[ 0.85275314, 0.14724686],\n",
" [ 0.12390183, 0.87609817],\n",
" [ 0.4612995 , 0.5387005 ],\n",
" [ 0.84924177, 0.15075823],\n",
" [ 0.49180096, 0.50819904],\n",
" [ 0.40086923, 0.59913077],\n",
" [ 0.28454427, 0.71545573],\n",
" [ 0.88776198, 0.11223802],\n",
" [ 0.84210373, 0.15789627]])"
]
},
"execution_count": 3,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model=SklearnWrapperLdaModel(num_topics=2,id2word=dictionary,iterations=20, random_state=1)\n",
"model.fit(corpus)\n",
"model.print_topics(2)"
"model.print_topics(2)\n",
"model.transform(corpus)"
]
},
{
Expand All @@ -135,9 +141,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 23,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -146,14 +152,14 @@
"from gensim.models.ldamodel import LdaModel\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel"
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklearnWrapperLdaModel"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 24,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -173,9 +179,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 25,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -196,7 +202,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 26,
"metadata": {
"collapsed": false
},
Expand All @@ -205,18 +211,18 @@
"data": {
"text/plain": [
"[(0,\n",
" u'0.018*\"cryptography\" + 0.018*\"face\" + 0.017*\"fierkelab\" + 0.008*\"abuse\" + 0.007*\"constitutional\" + 0.007*\"collection\" + 0.007*\"finish\" + 0.007*\"150\" + 0.007*\"fast\" + 0.006*\"difference\"'),\n",
" u'0.085*\"abroad\" + 0.053*\"ciphertext\" + 0.042*\"arithmetic\" + 0.037*\"facts\" + 0.031*\"courtesy\" + 0.025*\"amolitor\" + 0.023*\"argue\" + 0.021*\"asking\" + 0.020*\"agree\" + 0.018*\"classified\"'),\n",
" (1,\n",
" u'0.022*\"corporate\" + 0.022*\"accurate\" + 0.012*\"chance\" + 0.008*\"decipher\" + 0.008*\"example\" + 0.008*\"basically\" + 0.008*\"dawson\" + 0.008*\"cases\" + 0.008*\"consideration\" + 0.008*\"follow\"'),\n",
" u'0.098*\"asking\" + 0.075*\"cryptography\" + 0.068*\"abroad\" + 0.033*\"456\" + 0.025*\"argue\" + 0.022*\"bitnet\" + 0.017*\"false\" + 0.014*\"digex\" + 0.014*\"effort\" + 0.013*\"disk\"'),\n",
" (2,\n",
" u'0.034*\"argue\" + 0.031*\"456\" + 0.031*\"arithmetic\" + 0.024*\"courtesy\" + 0.020*\"beastmaster\" + 0.019*\"bitnet\" + 0.015*\"false\" + 0.015*\"classified\" + 0.014*\"cubs\" + 0.014*\"digex\"'),\n",
" u'0.023*\"accurate\" + 0.021*\"corporate\" + 0.013*\"clark\" + 0.012*\"chance\" + 0.009*\"consideration\" + 0.008*\"authentication\" + 0.008*\"dawson\" + 0.008*\"candidates\" + 0.008*\"basically\" + 0.008*\"assess\"'),\n",
" (3,\n",
" u'0.108*\"abroad\" + 0.089*\"asking\" + 0.060*\"cryptography\" + 0.035*\"certain\" + 0.030*\"ciphertext\" + 0.030*\"book\" + 0.028*\"69\" + 0.028*\"demand\" + 0.028*\"87\" + 0.027*\"cracking\"'),\n",
" u'0.016*\"cryptography\" + 0.007*\"evans\" + 0.006*\"considering\" + 0.006*\"forgot\" + 0.006*\"built\" + 0.005*\"constitutional\" + 0.005*\"fly\" + 0.004*\"cellular\" + 0.004*\"computed\" + 0.004*\"digitized\"'),\n",
" (4,\n",
" u'0.022*\"clark\" + 0.019*\"authentication\" + 0.017*\"candidates\" + 0.016*\"decryption\" + 0.015*\"attempt\" + 0.013*\"creation\" + 0.013*\"1993apr5\" + 0.013*\"acceptable\" + 0.013*\"algorithms\" + 0.013*\"employer\"')]"
" u'0.028*\"certain\" + 0.022*\"69\" + 0.021*\"book\" + 0.020*\"demand\" + 0.020*\"cracking\" + 0.020*\"87\" + 0.017*\"farm\" + 0.017*\"fierkelab\" + 0.015*\"face\" + 0.009*\"constitutional\"')]"
]
},
"execution_count": 7,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -233,72 +239,163 @@
"collapsed": true
},
"source": [
"#### Using together with Scikit learn's Logistic Regression"
"### Example for Using Grid Search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Now lets try Sklearn's logistic classifier to classify the given categories into two types.Ideally we should get postive weights when cryptography is talked about and negative when baseball is talked about."
"from sklearn.model_selection import GridSearchCV\n",
"from gensim.models.coherencemodel import CoherenceModel"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn import linear_model"
"def scorer(estimator, X,y=None):\n",
" goodcm = CoherenceModel(model=estimator, texts= texts, dictionary=estimator.id2word, coherence='c_v')\n",
" return goodcm.get_coherence()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5, error_score='raise',\n",
" estimator=SklearnWrapperLdaModel(alpha='symmetric', chunksize=2000, corpus=None,\n",
" decay=0.5, eta=None, eval_every=10, gamma_threshold=0.001,\n",
" id2word=<gensim.corpora.dictionary.Dictionary object at 0x7fb82cfbb7d0>,\n",
" iterations=50, minimum_probability=0.01, num_topics=5,\n",
" offset=1.0, passes=20, random_state=None, update_every=1),\n",
" fit_params={}, iid=True, n_jobs=1,\n",
" param_grid={'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n",
" scoring=<function scorer at 0x7fb82cfaf938>, verbose=0)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obj=SklearnWrapperLdaModel(id2word=dictionary,num_topics=5,passes=20)\n",
"parameters = {'num_topics':(2, 3, 5, 10), 'iterations':(1,20,50)}\n",
"model = GridSearchCV(obj, parameters, scoring=scorer, cv=5)\n",
"model.fit(corpus)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'iterations': 50, 'num_topics': 3}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example of Using Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 34,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def print_features(clf, vocab, n=10):\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn import linear_model\n",
"def print_features_pipe(clf, vocab, n=10):\n",
" ''' Better printing for sorted list '''\n",
" coef = clf.coef_[0]\n",
" coef = clf.named_steps['classifier'].coef_[0]\n",
" print coef\n",
" print 'Positive features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[::-1][:n] if coef[j] > 0]))\n",
" print 'Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0]))"
" print 'Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0]))\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"id2word=Dictionary(map(lambda x : x.split(),data.data))\n",
"corpus = [id2word.doc2bow(i.split()) for i in data.data]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Positive features: clipper:1.50 code:1.24 key:1.04 encryption:0.95 chip:0.37 nsa:0.37 government:0.36 uk:0.36 org:0.23 cryptography:0.23\n",
"Negative features: baseball:-1.32 game:-0.71 year:-0.61 team:-0.38 edu:-0.27 games:-0.26 players:-0.23 ball:-0.17 season:-0.14 phillies:-0.11\n"
"[ -2.95020466e-01 -1.04115352e-01 5.19570267e-01 1.03817059e-01\n",
" 2.72881013e-02 1.35738501e-02 1.89246630e-13 1.89246630e-13\n",
" 1.89246630e-13 1.89246630e-13 1.89246630e-13 1.89246630e-13\n",
" 1.89246630e-13 1.89246630e-13 1.89246630e-13]\n",
"Positive features: Fame,:0.52 Keach:0.10 comp.org.eff.talk,:0.03 comp.org.eff.talk.:0.01 >Pat:0.00 dome.:0.00 internet...:0.00 trawling:0.00 hanging:0.00 red@redpoll.neoucom.edu:0.00\n",
"Negative features: Fame.:-0.30 considered,:-0.10\n",
"0.531040268456\n"
]
}
],
"source": [
"clf=linear_model.LogisticRegression(penalty='l1', C=0.1) #l1 penalty used\n",
"clf.fit(X,data.target)\n",
"print_features(clf,vocab)"
"model=SklearnWrapperLdaModel(num_topics=15,id2word=id2word,iterations=50, random_state=37)\n",
"clf=linear_model.LogisticRegression(penalty='l2', C=0.1) #l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
"pipe.fit(corpus, data.target)\n",
"print_features_pipe(pipe, id2word.values())\n",
"print pipe.score(corpus, data.target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -317,7 +414,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"version": "2.7.13"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 97cd64f

Please sign in to comment.