Value expectation and 1st order CKY #93

Open. teffland wants to merge 14 commits into harvardnlp:master from teffland:value_expectation.
Changes from all commits (14):
- 90fc546 starting on full crf (teffland)
- e863fbb working full cky crf
- 498c964 debugging full cky
- 049dbd4 more changes
- 3c5dfbc [wip] add expectation semiring
- 73c8d7b add expected value semiring and test
- 6e1704a add full cky crf
- 2d0abe8 Update full_cky_crf.py (teffland)
- 297209a Update helpers.py (srush)
- 657fbc6 Update distributions.py (srush)
- 6edceb0 address review suggestions
- 19982f7 fix doc string errors
- 71004b2 switch value expectation to elementwise mul and reduce
- cded5e1 darglint ignore logpartition docstring mismatch
File: torch_struct/distributions.py
@@ -7,6 +7,7 @@
 from .alignment import Alignment
 from .deptree import DepTree, deptree_nonproj, deptree_part
 from .cky_crf import CKY_CRF
+from .full_cky_crf import Full_CKY_CRF
 from .semirings import (
     LogSemiring,
     MaxSemiring,
@@ -91,9 +92,7 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

     def kl(self, other):
         """
@@ -105,9 +104,7 @@ def kl(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

     @lazy_property
     def max(self):
@@ -140,9 +137,7 @@ def kmax(self, k):
            kmax (*k x batch_shape*)
        """
        with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)

    def topk(self, k):
        r"""
@@ -155,9 +150,7 @@ def topk(self, k):
            kmax (*k x batch_shape x event_shape*)
        """
        with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)

    @lazy_property
    def mode(self):
@@ -179,16 +172,33 @@ def marginals(self):

     @lazy_property
     def count(self):
-        "Compute the log-partition function."
+        "Compute the total number of structures in the CRF support set."
         ones = torch.ones_like(self.log_potentials)
         ones[self.log_potentials.eq(-float("inf"))] = 0
         return self._struct(StdSemiring).sum(ones, self.lengths)

+    def expected_value(self, values):
+        """
+        Compute the expected value of the distribution, :math:`E_z[f(z)]`, where f decomposes additively
+        over the factors of p_z.
+
+        Parameters:
+            values (:class:`torch.FloatTensor`): (*batch_shape x *event_shape x *value_shape), assigns a value
+                to each part of the structure. `values` can have 0 or more trailing dimensions in addition to
+                the `event_shape`, which allows for computing the expected value of, say, a vector-valued function.
+
+        Returns:
+            expected value (*batch_shape, *value_shape)
+        """
+        # For these "part-level" expectations, the result can be computed by multiplying the marginals
+        # element-wise with the values and summing. This is faster than the semiring approach because of
+        # FastLogSemiring. (Without genbmm it is about the same.)
+        ps = self.marginals
+        ps_bcast = ps.reshape(*ps.shape, *((1,) * (len(values.shape) - len(ps.shape))))
+        return ps_bcast.mul(values).reshape(ps.shape[0], -1, *values.shape[len(ps.shape) :]).sum(1)
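A usage sketch for expected_value (shapes are hypothetical; LinearChainCRF stands in for any StructDistribution here):

```python
import torch
from torch_struct import LinearChainCRF

dist = LinearChainCRF(torch.randn(2, 5, 3, 3))  # batch_shape=(2,), event_shape=(5, 3, 3)

# Scalar-valued f: one value per part of the structure.
v = torch.randn(2, 5, 3, 3)
print(dist.expected_value(v).shape)      # torch.Size([2])

# Vector-valued f: a trailing value dimension of size 4.
v_vec = torch.randn(2, 5, 3, 3, 4)
print(dist.expected_value(v_vec).shape)  # torch.Size([2, 4])
```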
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
-                self.log_potentials, self.lengths
-            )
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
             return st_gumbel

     # @constraints.dependent_property
@@ -214,16 +224,18 @@ def sample(self, sample_shape=torch.Size()):
         Returns:
             samples (*sample_shape x batch_shape x event_shape*)
         """
-        assert len(sample_shape) == 1
-        nsamples = sample_shape[0]
+        batch_size = MultiSampledSemiring.batch_size
+        if type(sample_shape) == int:
+            nsamples = sample_shape
+        else:
+            assert len(sample_shape) == 1
+            nsamples = sample_shape[0]
         samples = []
         for k in range(nsamples):
-            if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(
-                    self.log_potentials, lengths=self.lengths
-                )
+            if k % batch_size == 0:
+                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
                 sample = sample.detach()
-            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
+            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
             samples.append(tmp_sample)
         return torch.stack(samples)

Review comment on the `k % batch_size` line: "Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring."
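Since sample_shape may now be a plain int, both call styles work after this change (a small sketch with hypothetical shapes):

```python
import torch
from torch_struct import LinearChainCRF

dist = LinearChainCRF(torch.randn(2, 5, 3, 3))
s1 = dist.sample((100,))  # torch.Size-style sample_shape, as before
s2 = dist.sample(100)     # plain int, accepted after this change
assert s1.shape == s2.shape == (100, 2, 5, 3, 3)
```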
@@ -301,9 +313,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)

     def _struct(self, sr=None):
-        return self.struct(
-            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
-        )
+        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)


 class HMM(StructDistribution):
@@ -411,6 +421,32 @@ class TreeCRF(StructDistribution):
     struct = CKY_CRF


+class FullTreeCRF(StructDistribution):
+    r"""
+    Represents a 1st-order span parser with NT nonterminals. Implemented using a
+    fast CKY algorithm.
+
+    For a description see:
+
+    * Inside-Outside Algorithm, by Michael Collins
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
+            :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Implementation uses width-batched, forward-pass only
+
+    * Parallel Time: :math:`O(N)` parallel merges.
+    * Forward Memory: :math:`O(N^3 NT^3)`
+
+    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
+    """
+    struct = Full_CKY_CRF
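A construction sketch (sizes are hypothetical, and it assumes FullTreeCRF is exported at the package top level):

```python
import torch
from torch_struct import FullTreeCRF

batch, N, NT = 2, 6, 3
# One score per span (i, j), split point k, and rule A -> B C.
log_potentials = torch.randn(batch, N, N, N, NT, NT, NT)
dist = FullTreeCRF(log_potentials, lengths=torch.tensor([6, 5]))
print(dist.partition)        # log-partition, one value per batch element
print(dist.marginals.shape)  # (batch, N, N, N, NT, NT, NT)
```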
 class SentCFG(StructDistribution):
     """
     Represents a full generative context-free grammar with

@@ -440,9 +476,7 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(
-            batch_shape=batch_shape, event_shape=event_shape
-        )
+        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)


 class NonProjectiveDependencyCRF(StructDistribution):
File: torch_struct/full_cky_crf.py (new file)

@@ -0,0 +1,115 @@
import torch
from .helpers import _Struct

A, B = 0, 1


class Full_CKY_CRF(_Struct):
    def _check_potentials(self, edge, lengths=None):
        batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
        assert (
            N == N1 == N2 and NT == NT1 == NT2
        ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
        edge = self.semiring.convert(edge)
        semiring_shape = edge.shape[:-7]
        if lengths is None:
            lengths = torch.LongTensor([N] * batch).to(edge.device)

        return edge, semiring_shape, batch, N, NT, lengths

    def logpartition(self, scores, lengths=None, force_grad=False, cache=True):
        sr = self.semiring

        # Scores.shape = *sshape, B, N, N, N, NT, NT, NT
        # w/ semantics [*semiring stuff, b, i, j, k, A, B, C]
        # where b is the batch index, i is the left endpoint, j is the right endpoint,
        # k is the split point, with rule A -> B C
        scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
        sshape, sdims = list(sshape), list(range(len(sshape)))  # usually [0]
        S, b = len(sdims), batch

        # Initialize data structs
        LEFT, RIGHT = 0, 1
        L_DIM, R_DIM = S + 1, S + 2  # one and two to the right of the batch dim

        # Initialize the base cases with scores from the diagonal i=j=k, A=B=C
        term_scores = (
            scores.diagonal(0, L_DIM, R_DIM)  # diag i,j now at dim -1
            .diagonal(0, L_DIM, -1)  # diag of k with that gives i=j=k, now at dim -1
            .diagonal(0, -4, -3)  # diag of A, B, now at dim -1, ijk moves to -2
            .diagonal(0, -3, -1)  # diag of C with that gives A=B=C
        )
        assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
        alpha_left = term_scores
        alpha_right = term_scores
        alphas = [[alpha_left], [alpha_right]]

        # Run the vectorized inside algorithm
        for w in range(1, N):
            # Scores
            # What we want is a tensor with:
            #   shape: *sshape, batch, (N-w), NT, w, NT, NT
            #   semantics: [...batch, (i, j=i+w), A, k, B, C]
            # where (i, j=i+w) means the diagonal of tree nodes with width w.
            # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [...batch, k, A, B, C, (i, j=i+w)]
            score = scores.diagonal(w, L_DIM, R_DIM)  # get diagonal scores

            score = score.permute(sdims + [-6, -1, -4, -5, -3, -2])  # move diag (-1) dim and head NT (-4) dim to front
            score = score[..., :w, :, :]  # remove illegal split points
            assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N - w, NT, w, NT, NT)}"

            # Sums of left subtrees
            # Shape: *sshape, batch, (N-w), w, NT
            # where L[..., i, d, B] is the sum of subtrees up to (i, j=(i+d), B)
            left = slice(None, N - w)  # left indices
            L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]

            # Sums of right subtrees
            # Shape: *sshape, batch, (N-w), w, NT
            # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d), j=(N-h), C)
            right = slice(w, None)  # right indices
            R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]

            # Broadcast them both to match the missing dims in score
            # Left B is duplicated for all head and right symbols A, C
            L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1])

            # Right C is duplicated for all head and left symbols A, B
            R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT])

            # Now multiply all the scores and sum over the k, B, C dimensions (the last three dims)
            assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])
            # sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
            sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1))
            assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b, N - w, NT)}"

            pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
            sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
            sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
            alphas[LEFT].append(sum_prod_w_left)
            alphas[RIGHT].append(sum_prod_w_right)

        final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :]  # sum out root symbol
        log_Z = final[:, torch.arange(batch), lengths - 1]
        return log_Z, [scores]
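A quick sketch of how this struct could be exercised directly (sizes are hypothetical; this assumes the usual _Struct interface where sum computes the log-partition):

```python
import torch
from torch_struct.full_cky_crf import Full_CKY_CRF
from torch_struct.semirings import LogSemiring

batch, N, NT = 2, 5, 2
scores = torch.randn(batch, N, N, N, NT, NT, NT)
log_Z = Full_CKY_CRF(LogSemiring).sum(scores)  # shape: (batch,)
print(log_Z)
```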
srush: Just curious, why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.
teffland: Sorry, maybe I'm confused, but isn't this enumerating over all possible structures explicitly?
srush: Oh sorry, my comment is confusing. I think a valid way of computing an expectation over any "part-level" value is to first compute the marginals (.marginals), then do an element-wise multiply with the values (.mul), then sum. Doesn't that give you the same thing as the semiring?
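A minimal sketch of that identity (hypothetical shapes; the Monte Carlo comparison is only a sanity check):

```python
import torch
from torch_struct import LinearChainCRF

dist = LinearChainCRF(torch.randn(2, 5, 3, 3))
values = torch.randn(2, 5, 3, 3)  # one value per part; f(z) sums values over the parts in z

# Expectation via marginals: E[f(z)] = <marginals, values>.
expected = dist.marginals.mul(values).reshape(2, -1).sum(-1)

# A Monte Carlo estimate from samples should agree up to noise.
samples = dist.sample((1000,))  # (1000, 2, 5, 3, 3), one-hot part indicators
mc = samples.mul(values).reshape(1000, 2, -1).sum(-1).mean(0)
print(expected, mc)
```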
teffland: Oh wow, I didn't realize this! I just tested it out, and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.
srush: Yeah, I think that is right. I haven't thought about this too much, but my guess is that it is simply better on GPU hardware, since the expectation is batched at the end. It seems worth understanding when this works, though. I don't think you can compute entropy this way? (But I might be wrong.)
teffland: Makes sense. I also don't think entropy can be done this way; I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.

Fwiw, I ran a quick speed comparison you might be interested in:

[Speed comparison plots: results from running with genbmm and without genbmm.]
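A rough harness for reproducing that kind of comparison (illustrative only; the sizes and the choice of TreeCRF are assumptions, not the setup used above):

```python
import time
import torch
from torch_struct import TreeCRF

for N in [10, 20, 40]:
    dist = TreeCRF(torch.randn(4, N, N, 2))  # batch=4, NT=2
    values = torch.randn(4, N, N, 2)
    start = time.perf_counter()
    dist.expected_value(values)
    print(N, time.perf_counter() - start)
```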