-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
module for maintaining probability distributions
- Loading branch information
1 parent
0b5c38d
commit f61168c
Showing
1 changed file
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/*++ | ||
Copyright (c) 2017 Microsoft Corporation | ||
Module Name: | ||
distribution.h | ||
Abstract: | ||
Probabiltiy distribution | ||
Author: | ||
Nikolaj Bjorner (nbjorner) 2023-4-12 | ||
Notes: | ||
Distribution class works by pushing identifiers with associated scores. | ||
After they have been pushed, you can access a random element using choose | ||
or you can enumerate the elements in random order, sorted by the score probability. | ||
--*/ | ||
#pragma once | ||
|
||
#include "vector.h" | ||
|
||
class distribution { | ||
|
||
random_gen m_random; | ||
svector<std::pair<unsigned, unsigned>> m_elems; | ||
unsigned m_sum = 0; | ||
|
||
unsigned choose(unsigned sum) { | ||
unsigned s = m_random(sum); | ||
for (auto const& [j, score] : m_elems) { | ||
if (s < score) | ||
return j; | ||
s -= score; | ||
} | ||
UNREACHABLE(); | ||
return 0; | ||
} | ||
|
||
public: | ||
|
||
distribution(unsigned seed): m_random(seed) {} | ||
|
||
void reset() { | ||
m_elems.reset(); | ||
m_sum = 0; | ||
} | ||
|
||
bool empty() const { | ||
return m_elems.empty(); | ||
} | ||
|
||
void push(unsigned id, unsigned score) { | ||
SASSERT(score > 0); | ||
if (score > 0) { | ||
m_elems.push_back({id, score}); | ||
m_sum += score; | ||
} | ||
} | ||
|
||
/** | ||
\brief choose an element at random with probability proportional to the score | ||
relative to the sum of scores of other. | ||
*/ | ||
unsigned choose() { | ||
return m_elems[choose(m_sum)].first; | ||
} | ||
|
||
class iterator { | ||
distribution& d; | ||
unsigned m_sz = 0; | ||
unsigned m_sum = 0; | ||
unsigned m_index = 0; | ||
void next_index() { | ||
if (0 == m_sz) | ||
return; | ||
m_index = d.choose(m_sum); | ||
} | ||
public: | ||
iterator(distribution& d, bool start): d(d), m_sz(start?d.m_elems.size():0), m_sum(d.m_sum) { | ||
next_index(); | ||
} | ||
unsigned operator*() const { return d.m_elems[m_index].first; } | ||
iterator operator++() { | ||
m_sum -= d.m_elems[m_index].second; | ||
--m_sz; | ||
std::swap(d.m_elems[m_index], d.m_elems[d.m_elems.size() - 1]); | ||
next_index(); | ||
} | ||
bool operator==(iterator const& other) const { return m_sz == other.m_sz; } | ||
bool operator!=(iterator const& other) const { return m_sz != other.m_sz; } | ||
}; | ||
|
||
iterator begin() { return iterator(*this, true); } | ||
iterator end() { return iterator(*this, false); } | ||
}; |