HMM labeled training issue #4

Closed
AKuederle opened this issue Nov 15, 2022 · 1 comment

Comments

@AKuederle

@nilsroth I think there was a bug in the training code for labeled training of the combined model.

In the code snippet below, you attempt to set only the edges that do not exist yet.
However, you assume that, e.g., the entry for the state transition 15 to 16 is at position (15, 16) in the transition matrix. Because of the "sorting" bug in pomegranate, this is not the case.
As a result, some edges are not written and some edges are overwritten, even though they already exist.

# Add missing transitions which will "connect" transition-hmm and stride-hmm
for trans in transitions:
    # if edge already exists, skip
    if not model_untrained.dense_transition_matrix()[trans[0], trans[1]]:
        add_transition(model_untrained, ["s%d" % (trans[0]), "s%d" % (trans[1])], 0.1)
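
To illustrate why the index-based lookup can go wrong, here is a minimal standalone sketch. It does not use pomegranate at all. It only assumes that states end up ordered by their names sorted as strings, which is my reading of the "sorting" bug and may not match the library's actual behavior exactly. The names `state_names`, `position`, and `mismatched` are made up for this example.

```python
# Hypothetical state names following the "s%d" pattern used in the snippet above.
state_names = [f"s{i}" for i in range(20)]

# ASSUMPTION: states are ordered by sorting their names as strings,
# so "s10" comes before "s2".
sorted_names = sorted(state_names)

# Row/column position each state would get in a matrix built from that ordering.
position = {name: idx for idx, name in enumerate(sorted_names)}

# Under this assumption, the transition s15 -> s16 is not stored at (15, 16).
print(position["s15"], position["s16"])

# States whose numeric label differs from their position in the sorted order.
mismatched = [name for name in state_names if position[name] != int(name[1:])]
print(len(mismatched), "of", len(state_names), "states are misplaced")
```

Under this assumption, only s0 and s1 keep their expected position, so checking `dense_transition_matrix()[15, 16]` would look at the wrong pair of states.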

I fixed that by collecting the existing transitions based on the actual state names and checking against those.

Could you check whether the same issue exists in your old training code?

With my fix, the train-hmm example now produces an actually usable HMM :)

@AKuederle

Here is the updated code:

existing_transitions = {(start.name, end.name) for start, end in new_model.graph.edges()}
missing_transitions = transitions - existing_transitions
# Add missing transitions which will "connect" transition-hmm and stride-hmm
# We initialize with a very small probability, so that the model can learn the correct values in the next
# step.
for trans in missing_transitions:
    add_transition(new_model, trans, 0.1)

To make that work, I also changed extract_transitions_starts_stops_from_hidden_state_sequence to return a set containing the actual state names instead of the state indices.
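
For reference, a rough sketch of the name-based approach in isolation. This is not the actual implementation of the function mentioned above. `transitions_as_state_names` is a hypothetical helper, the "s" prefix is assumed to match the default state naming, and keeping self-transitions is my choice for the example.

```python
from typing import Iterable, Set, Tuple


def transitions_as_state_names(
    hidden_states: Iterable[int], prefix: str = "s"
) -> Set[Tuple[str, str]]:
    """Hypothetical helper: collect consecutive state pairs as (from_name, to_name)."""
    seq = list(hidden_states)
    # Pair each state with its successor; self-transitions are kept as well.
    return {(f"{prefix}{a}", f"{prefix}{b}") for a, b in zip(seq, seq[1:])}


# Transitions observed in a made-up labeled hidden state sequence.
observed = transitions_as_state_names([0, 0, 1, 2, 2, 0])

# Made-up set of edges that the model already contains, keyed by state name.
existing = {("s0", "s0"), ("s0", "s1"), ("s1", "s2"), ("s2", "s2")}

# Only these edges would still need to be added to the model.
missing = observed - existing
print(missing)
```

Because both sets are keyed by state name, the comparison does not depend on how the states are ordered inside the transition matrix.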

Still a little bit scary that we depend on the default state naming of pomegranate... But it is what it is...
