-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathw2v_text_embed.py
72 lines (65 loc) · 2.33 KB
/
w2v_text_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import pandas as pd
import polars as pl
import re
class W2VEmbed:
    """
    Use word2vec word embedding as a crude document embedding (bag of word2vec).
    https://en.wikipedia.org/wiki/Word2vec

    A document is embedded as the mean of the vectors of its known tokens.
    Tokens are maximal runs of ASCII letters and digits; lookup is
    case-sensitive and tokens that are not in the vocabulary are ignored.
    """

    # Any run of characters other than ASCII letters/digits separates tokens.
    _SEPARATOR = re.compile(r"[^0-9a-zA-Z]+")

    def __init__(self, d) -> None:
        """
        Build the word -> vector lookup table.

        :param d: Pandas or Polars data frame with word2vec vectors. First column should be named "word", all other columns should contain the numeric embedding values.
        """
        if not isinstance(d, pl.DataFrame):
            d = pl.DataFrame(d)
        assert d.columns[0] == "word", 'first column must be named "word"'
        # every column other than "word" is one embedding dimension
        self.n_dim = len([c for c in d.columns if c != "word"])
        # copy data into a map; iter_rows() walks the frame once rather than
        # fetching each row separately by index. A repeated word keeps its
        # last vector.
        self.mp = {row[0]: np.array(row[1:]) for row in d.iter_rows()}

    def transform_str(self, doc: str) -> pd.DataFrame:
        """
        Crude word2vec based document embedder. Embeds documents to mean word2vec word encoding. For demonstration only.

        :param doc: string to be transformed
        :return: single row data frame representing embedding of document
                 (all zeros when no token of the document is in the vocabulary)
        """
        assert isinstance(doc, str), "doc must be a str"
        n_toks = 0
        enc = np.zeros(self.n_dim)
        # split() already discards leading/trailing blanks left by the substitution
        for tok in self._SEPARATOR.sub(" ", doc).split():
            v = self.mp.get(tok)
            if v is None:
                continue  # out-of-vocabulary token: contributes nothing
            n_toks += 1
            enc = enc + v
        if n_toks > 0:
            enc = enc / n_toks
        # a 1-d array becomes a single column; transpose to get a single row
        return pd.DataFrame(enc).transpose()

    def transform(self, X) -> pd.DataFrame:
        """
        Transform a data column.

        :param X: iterable series of strings (not a single string). For a
                  Pandas/Polars data frame or a 2-d ndarray the first column is used.
        :return: pd.DataFrame with one row per document and one column per
                 embedding dimension; zero rows when X is empty
        """
        # if a data frame or ndarray, get the first column
        if isinstance(X, pl.DataFrame):
            X = X[:, 0]
        elif isinstance(X, pd.DataFrame):
            X = X.iloc[:, 0].values
        elif isinstance(X, np.ndarray):
            assert 1 <= len(X.shape) <= 2, "ndarray input must be 1-d or 2-d"
            if len(X.shape) > 1:
                X = X[:, 0]
        # common error: a bare string would be embedded character by character
        assert not isinstance(X, str), "X must be an iterable of strings, not a single string"
        frames = [self.transform_str(text) for text in X]
        if not frames:
            # pd.concat refuses an empty list; return an empty frame of the right width
            return pd.DataFrame(np.zeros((0, self.n_dim)))
        return pd.concat(frames, ignore_index=True)