Skip to content

Commit

Permalink
scikit_learn wrapper for LSI Model in Gensim (#1244)
Browse files Browse the repository at this point in the history
* removed unnecessary keep_bocab_item import

* removed duplicate warnings import

* updated warning message for trim_rule

* added wrapper class for lsimodel

* removed unnecessary print statement

* added tests for lsi wrapper

* changed name from testPrintTopic to testModelSanity and made defaults explicit

* added pipeline example for LsiModel
  • Loading branch information
chinmayapancholi13 authored and tmylk committed Mar 28, 2017
1 parent 2afab97 commit a83e61b
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 34 deletions.
136 changes: 104 additions & 32 deletions docs/notebooks/sklearn_wrapper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel```"
"This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The wrapper available (as of now) are :\n",
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface"
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface\n",
"\n",
"* LsiModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsiModel.SklearnWrapperLsiModel```),which implements gensim's ```LsiModel``` in a scikit-learn interface"
]
},
{
Expand All @@ -38,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 1,
"metadata": {
"collapsed": false
},
Expand All @@ -56,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 2,
"metadata": {
"collapsed": true
},
Expand Down Expand Up @@ -85,7 +87,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 3,
"metadata": {
"collapsed": false
},
Expand All @@ -111,7 +113,7 @@
" [ 0.84210373, 0.15789627]])"
]
},
"execution_count": 22,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -129,7 +131,7 @@
"collapsed": true
},
"source": [
"### Integration with Sklearn"
"#### Integration with Sklearn"
]
},
{
Expand All @@ -141,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 4,
"metadata": {
"collapsed": false
},
Expand All @@ -157,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 5,
"metadata": {
"collapsed": false
},
Expand All @@ -179,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 6,
"metadata": {
"collapsed": false
},
Expand All @@ -202,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 7,
"metadata": {
"collapsed": false
},
Expand All @@ -211,18 +213,18 @@
"data": {
"text/plain": [
"[(0,\n",
" u'0.085*\"abroad\" + 0.053*\"ciphertext\" + 0.042*\"arithmetic\" + 0.037*\"facts\" + 0.031*\"courtesy\" + 0.025*\"amolitor\" + 0.023*\"argue\" + 0.021*\"asking\" + 0.020*\"agree\" + 0.018*\"classified\"'),\n",
" u'0.025*\"456\" + 0.021*\"argue\" + 0.016*\"bitnet\" + 0.015*\"beastmaster\" + 0.014*\"cryptography\" + 0.013*\"false\" + 0.012*\"digex\" + 0.011*\"cover\" + 0.011*\"classified\" + 0.010*\"disk\"'),\n",
" (1,\n",
" u'0.098*\"asking\" + 0.075*\"cryptography\" + 0.068*\"abroad\" + 0.033*\"456\" + 0.025*\"argue\" + 0.022*\"bitnet\" + 0.017*\"false\" + 0.014*\"digex\" + 0.014*\"effort\" + 0.013*\"disk\"'),\n",
" u'0.142*\"abroad\" + 0.113*\"asking\" + 0.088*\"cryptography\" + 0.044*\"ciphertext\" + 0.043*\"arithmetic\" + 0.032*\"courtesy\" + 0.030*\"facts\" + 0.021*\"argue\" + 0.019*\"amolitor\" + 0.018*\"agree\"'),\n",
" (2,\n",
" u'0.023*\"accurate\" + 0.021*\"corporate\" + 0.013*\"clark\" + 0.012*\"chance\" + 0.009*\"consideration\" + 0.008*\"authentication\" + 0.008*\"dawson\" + 0.008*\"candidates\" + 0.008*\"basically\" + 0.008*\"assess\"'),\n",
" u'0.034*\"certain\" + 0.027*\"69\" + 0.025*\"book\" + 0.025*\"demand\" + 0.024*\"87\" + 0.024*\"cracking\" + 0.021*\"farm\" + 0.019*\"fierkelab\" + 0.015*\"face\" + 0.011*\"abroad\"'),\n",
" (3,\n",
" u'0.016*\"cryptography\" + 0.007*\"evans\" + 0.006*\"considering\" + 0.006*\"forgot\" + 0.006*\"built\" + 0.005*\"constitutional\" + 0.005*\"fly\" + 0.004*\"cellular\" + 0.004*\"computed\" + 0.004*\"digitized\"'),\n",
" u'0.017*\"decipher\" + 0.017*\"example\" + 0.016*\"cases\" + 0.016*\"follow\" + 0.008*\"considering\" + 0.006*\"forgot\" + 0.006*\"cellular\" + 0.005*\"evans\" + 0.005*\"computed\" + 0.005*\"cia\"'),\n",
" (4,\n",
" u'0.028*\"certain\" + 0.022*\"69\" + 0.021*\"book\" + 0.020*\"demand\" + 0.020*\"cracking\" + 0.020*\"87\" + 0.017*\"farm\" + 0.017*\"fierkelab\" + 0.015*\"face\" + 0.009*\"constitutional\"')]"
" u'0.022*\"accurate\" + 0.021*\"corporate\" + 0.013*\"chance\" + 0.012*\"clark\" + 0.009*\"consideration\" + 0.009*\"candidates\" + 0.008*\"dawson\" + 0.008*\"authentication\" + 0.008*\"assess\" + 0.008*\"attempt\"')]"
]
},
"execution_count": 26,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -239,12 +241,12 @@
"collapsed": true
},
"source": [
"### Example for Using Grid Search"
"#### Example for Using Grid Search"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 8,
"metadata": {
"collapsed": false
},
Expand All @@ -256,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 9,
"metadata": {
"collapsed": true
},
Expand All @@ -269,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 10,
"metadata": {
"collapsed": false
},
Expand All @@ -280,16 +282,16 @@
"GridSearchCV(cv=5, error_score='raise',\n",
" estimator=SklearnWrapperLdaModel(alpha='symmetric', chunksize=2000, corpus=None,\n",
" decay=0.5, eta=None, eval_every=10, gamma_threshold=0.001,\n",
" id2word=<gensim.corpora.dictionary.Dictionary object at 0x7fb82cfbb7d0>,\n",
" id2word=<gensim.corpora.dictionary.Dictionary object at 0x7f42ccbebd10>,\n",
" iterations=50, minimum_probability=0.01, num_topics=5,\n",
" offset=1.0, passes=20, random_state=None, update_every=1),\n",
" fit_params={}, iid=True, n_jobs=1,\n",
" param_grid={'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n",
" scoring=<function scorer at 0x7fb82cfaf938>, verbose=0)"
" scoring=<function scorer at 0x7f42cad12230>, verbose=0)"
]
},
"execution_count": 32,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -303,18 +305,18 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'iterations': 50, 'num_topics': 3}"
"{'iterations': 20, 'num_topics': 3}"
]
},
"execution_count": 33,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -327,14 +329,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example of Using Pipeline"
"#### Example of Using Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 12,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -350,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 13,
"metadata": {
"collapsed": false
},
Expand All @@ -362,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 14,
"metadata": {
"collapsed": false
},
Expand Down Expand Up @@ -396,6 +398,76 @@
"print_features_pipe(pipe, id2word.values())\n",
"print pipe.score(corpus, data.target)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LsiModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use LsiModel begin with importing LsiModel wrapper"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example of Using Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.13652819 0.00383696 0.02635504 -0.08454895 -0.02356143 0.60020084\n",
" 1.07026252 -0.04072257 0.43732847 0.54913549 -0.20242834 -0.21855402\n",
" -1.30546283 -0.08690711 0.17606255]\n",
"Positive features: 01101001B:1.07 comp.org.eff.talk.:0.60 red@redpoll.neoucom.edu:0.55 circuitry:0.44 >Pat:0.18 Fame.:0.14 Fame,:0.03 considered,:0.00\n",
"Negative features: internet...:-1.31 trawling:-0.22 hanging:-0.20 dome.:-0.09 Keach:-0.08 *best*:-0.04 comp.org.eff.talk,:-0.02\n",
"0.865771812081\n"
]
}
],
"source": [
"model=SklearnWrapperLsiModel(num_topics=15, id2word=id2word)\n",
"clf=linear_model.LogisticRegression(penalty='l2', C=0.1) #l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
"pipe.fit(corpus, data.target)\n",
"print_features_pipe(pipe, id2word.values())\n",
"print pipe.score(corpus, data.target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit a83e61b

Please sign in to comment.