Value expectation and 1st order CKY #93

Open
wants to merge 14 commits into master
12 changes: 12 additions & 0 deletions tests/test_distributions.py
@@ -76,6 +76,18 @@ def test_simple(data, seed):
dist.kmax(5)
dist.count

val_func = torch.rand(*vals.shape, 10)
E_val = dist.expected_value(val_func)
struct_vals = (
edges.unsqueeze(-1)
.mul(val_func.unsqueeze(0))
.reshape(*edges.shape[:2], -1, val_func.shape[-1])
.sum(2)
)
assert torch.isclose(
E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
Collaborator

Just curious. Why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.

Author

Sorry, maybe I'm confused but isn't this enumerating over all possible structures explicitly?

Collaborator

Oh sorry, my comment is confusing.

I think a valid way of computing an expectation over any "part-level value" is to first compute the marginals (.marginals), then do an elementwise multiply (.mul), and then sum. Doesn't that give you the same thing as the semiring?
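
To make this concrete, a minimal sketch of the marginal-based expectation (a sketch only; names and shapes are illustrative, assuming a LinearChainCRF with edge potentials phis and per-part values vals):

import torch
from torch_struct import LinearChainCRF

B, N, C, D = 4, 20, 5, 10
phis = torch.randn(B, N, C, C)       # edge log-potentials
vals = torch.randn(B, N, C, C, D)    # a D-dimensional value attached to each part

dist = LinearChainCRF(phis)
p = dist.marginals                   # part marginals, same shape as phis
E = p.unsqueeze(-1).mul(vals).reshape(B, -1, D).sum(1)   # expected value, shape (B, D)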

Author

Oh wow, I didn't realize this! I just tested it out and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.

Collaborator

Yeah, I think that is right... I haven't thought about this too much, but my guess is that this is just better on GPU hardware since the expectation is batched at the end. But it seems worth understanding when this works. I don't think you can compute entropy this way? (But I might be wrong.)
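
For reference, a quick way to check on a small example (a sketch only, assuming a LinearChainCRF; entropy does not decompose additively over parts, so the marginal trick should not match it):

import torch
from torch_struct import LinearChainCRF

phis = torch.randn(2, 6, 4, 4)
dist = LinearChainCRF(phis)

H = dist.entropy                      # exact entropy via the entropy semiring
p = dist.marginals
H_parts = -(p * p.clamp_min(1e-12).log()).reshape(2, -1).sum(-1)   # naive part-level attempt
# H and H_parts generally differ, unlike the expected-value case above.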

Author

Makes sense. I also don't think entropy can be done this way -- I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.

FWIW, I ran a quick speed comparison you might be interested in:

B, N, C = 4, 200, 10

phis = torch.randn(B,N,C,C).cuda()
vals = torch.randn(B,N,C,C,10).cuda()

Results from running w/ genbmm

%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 6.34 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B,-1,vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 5.64 ms per loop

Results from running w/o genbmm

%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 9.67 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B,-1,vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 8.83 ms per loop

).all(), "Efficient expected value not equal to enumeration"


@given(data(), integers(min_value=1, max_value=20))
@settings(max_examples=50, deadline=None)
92 changes: 63 additions & 29 deletions torch_struct/distributions.py
@@ -7,6 +7,7 @@
from .alignment import Alignment
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .full_cky_crf import Full_CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
@@ -91,9 +92,7 @@ def cross_entropy(self, other):
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

def kl(self, other):
"""
@@ -105,9 +104,7 @@ def kl(self, other):
Returns:
cross entropy (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

@lazy_property
def max(self):
@@ -140,9 +137,7 @@ def kmax(self, k):
kmax (*k x batch_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).sum(
self.log_potentials, self.lengths, _raw=True
)
return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)

def topk(self, k):
r"""
@@ -155,9 +150,7 @@ def topk(self, k):
kmax (*k x batch_shape x event_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).marginals(
self.log_potentials, self.lengths, _raw=True
)
return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)

@lazy_property
def mode(self):
@@ -179,16 +172,33 @@ def marginals(self):

@lazy_property
def count(self):
"Compute the log-partition function."
"Compute the total number of structures in the CRF support set."
ones = torch.ones_like(self.log_potentials)
ones[self.log_potentials.eq(-float("inf"))] = 0
return self._struct(StdSemiring).sum(ones, self.lengths)

def expected_value(self, values):
"""
Compute the expected value of the distribution, :math:`E_z[f(z)]`, where f decomposes additively over the factors of p_z.

Parameters:
values (:class:`torch.FloatTensor`): (*batch_shape x *event_shape x *value_shape), assigns a value to each
part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`,
which allows computing the expected value of, say, a vector-valued function.

Returns:
expected value (*batch_shape, *value_shape)
"""
# For these "part-level" expectations, this can be computed by multiplying the marginals element-wise
# on the values and summing. This is faster than the semiring because of FastLogSemiring.
# (w/o genbmm it's about the same.)
ps = self.marginals
ps_bcast = ps.reshape(*ps.shape, *((1,) * (len(values.shape) - len(ps.shape))))
return ps_bcast.mul(values).reshape(ps.shape[0], -1, *values.shape[len(ps.shape) :]).sum(1)

def gumbel_crf(self, temperature=1.0):
with torch.enable_grad():
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
self.log_potentials, self.lengths
)
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
return st_gumbel

# @constraints.dependent_property
@@ -214,16 +224,18 @@ def sample(self, sample_shape=torch.Size()):
Returns:
samples (*sample_shape x batch_shape x event_shape*)
"""
assert len(sample_shape) == 1
nsamples = sample_shape[0]
batch_size = MultiSampledSemiring.batch_size
if type(sample_shape) == int:
nsamples = sample_shape
else:
assert len(sample_shape) == 1
nsamples = sample_shape[0]
samples = []
for k in range(nsamples):
if k % 10 == 0:
sample = self._struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
if k % batch_size == 0:
Collaborator

Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring.
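
From the caller's side, a rough sketch of what that looks like (assuming this PR's batch_size attribute on the semiring class; the loop body is elided):

from torch_struct.semirings import MultiSampledSemiring

# the sampling loop reads the constant from the class instead of hard-coding 10
nsamples = 25
for k in range(nsamples):
    if k % MultiSampledSemiring.batch_size == 0:
        pass  # draw a fresh packed batch of samples here, as in the diff below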

sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
samples.append(tmp_sample)
return torch.stack(samples)

@@ -301,9 +313,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
super().__init__(log_potentials, lengths)

def _struct(self, sr=None):
return self.struct(
sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
)
return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)


class HMM(StructDistribution):
@@ -411,6 +421,32 @@ class TreeCRF(StructDistribution):
struct = CKY_CRF


class FullTreeCRF(StructDistribution):
r"""
Represents a 1st-order span parser with NT nonterminals. Implemented using a
fast CKY algorithm.

For a description see:

* Inside-Outside Algorithm, by Michael Collins

Event shape is of the form:

Parameters:
log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
:math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
lengths (long tensor) : batch shape integers for length masking.

Implementation uses width-batched, forward-pass only

* Parallel Time: :math:`O(N)` parallel merges.
* Forward Memory: :math:`O(N^3 NT^3)`

Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
"""
struct = Full_CKY_CRF


class SentCFG(StructDistribution):
"""
Represents a full generative context-free grammar with
@@ -440,9 +476,7 @@ def __init__(self, log_potentials, lengths=None):
event_shape = log_potentials[0].shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
super(StructDistribution, self).__init__(
batch_shape=batch_shape, event_shape=event_shape
)
super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)


class NonProjectiveDependencyCRF(StructDistribution):
115 changes: 115 additions & 0 deletions torch_struct/full_cky_crf.py
@@ -0,0 +1,115 @@
import torch
from .helpers import _Struct

A, B = 0, 1


class Full_CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
assert (
N == N1 == N2 and NT == NT1 == NT2
), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
edge = self.semiring.convert(edge)
semiring_shape = edge.shape[:-7]
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)

return edge, semiring_shape, batch, N, NT, lengths

def logpartition(self, scores, lengths=None, force_grad=False, cache=True):
sr = self.semiring

# Scores.shape = *sshape, B, N, N, N, NT, NT, NT
# w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
# where b is the batch index, i is the left endpoint, j is the right endpoint, k is the split point, and the rule is A -> B C
scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
sshape, sdims = list(sshape), list(range(len(sshape))) # usually [0]
S, b = len(sdims), batch

# Initialize data structs
LEFT, RIGHT = 0, 1
L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim

# Initialize the base cases with scores from diagonal i=j=k, A=B=C
term_scores = (
scores.diagonal(0, L_DIM, R_DIM) # diag i,j now at dim -1
.diagonal(0, L_DIM, -1) # diag of k with the above, which gives i=j=k; now at dim -1
.diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2
.diagonal(0, -3, -1) # diag of C with the above, which gives A=B=C
)
assert term_scores.shape[S + 1 :] == (
N,
NT,
), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
alpha_left = term_scores
alpha_right = term_scores
alphas = [[alpha_left], [alpha_right]]

# Run vectorized inside alg
for w in range(1, N):
# Scores
# What we want is a tensor with:
# shape: *sshape, batch, (N-w), NT, w, NT, NT
# w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
# where (i,j=i+w) means the diagonal of tree nodes with width w
# Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores

score = score.permute(
sdims + [-6, -1, -4, -5, -3, -2]
) # move diag (-1) dim and head NT (-4) dim to front
score = score[..., :w, :, :] # remove illegal splitpoints
assert score.shape[S:] == (
batch,
N - w,
NT,
w,
NT,
NT,
), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}"

# Sums of left subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
left = slice(None, N - w) # left indices
L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]

# Sums of right subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
right = slice(w, None) # right indices
R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]

# Broadcast them both to match missing dims in score
# Left B is duplicated for all head and right symbols A C
L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1])

# Right C is duplicated for all head and left symbols A B
R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT])

# Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
assert sr.times(score, L_bcast, R_bcast).shape == tuple(
list(sshape) + [b, N - w, NT, w, NT, NT]
)
# sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
sum_prod_w = sr.sum(
sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1)
)
assert sum_prod_w.shape[S:] == (
b,
N - w,
NT,
), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}"

pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
alphas[LEFT].append(sum_prod_w_left)
alphas[RIGHT].append(sum_prod_w_right)

final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[
..., 0, :
] # sum out root symbol
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores]
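
Outside the diff, a minimal usage sketch of the new distribution (names are illustrative; FullTreeCRF is the class this PR adds in torch_struct/distributions.py):

import torch
from torch_struct.distributions import FullTreeCRF

B, N, NT = 2, 5, 3
# phi(i, j, k, A_i^j -> B_i^k C_{k+1}^j): one score per span, split point, and rule
log_potentials = torch.randn(B, N, N, N, NT, NT, NT)
dist = FullTreeCRF(log_potentials)

log_Z = dist.partition    # log-partition via the width-batched inside pass above
marg = dist.marginals     # marginals over the (i, j, k, A, B, C) parts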