-
Notifications
You must be signed in to change notification settings - Fork 4
/
hybrid.py
59 lines (47 loc) · 2.19 KB
/
hybrid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import random
import sys
from .strategy import Strategy
from .DATE import DATESampling
from main import initialize_sampler
from utils import timer_func
class HybridSampling(Strategy):
    """Sampling strategy that splits each query budget among several sub-samplers.

    The sub-samplers are built from ``args.subsamplings`` and their budget
    shares from ``args.weights``; both are "/"-separated strings and must
    have the same number of entries. The weights must sum to 1.
    """

    def __init__(self, args):
        """Build one sub-sampler per name in ``args.subsamplings``.

        Raises:
            AssertionError: if the weights do not sum to 1 (to 10 decimal
                places) or their count differs from the number of
                sub-samplers.
        """
        super().__init__(args)
        self.subsamps = [
            initialize_sampler(subsamp, args)
            for subsamp in args.subsamplings.split("/")
        ]
        self.weights = [float(weight) for weight in args.weights.split("/")]
        # Rounded comparison tolerates float noise such as 0.1 + 0.2 + 0.7.
        assert round(sum(self.weights), 10) == 1
        assert len(self.subsamps) == len(self.weights)
        # self.available_indices = None # Needed!

    def set_data(self, data):
        """Register ``data`` on this strategy and on every sub-sampler."""
        super().set_data(data)
        for subsamp in self.subsamps:
            subsamp.set_data(data)

    def set_weights(self, weights):
        """Replace the per-sub-sampler weights (not re-validated here)."""
        self.weights = weights

    def get_weights(self):
        """Return the current per-sub-sampler weights."""
        return self.weights

    def set_uncertainty_module(self, uncertainty_module):
        """Attach ``uncertainty_module`` to this strategy and all sub-samplers.

        Sub-samplers receive it by direct attribute assignment rather than
        through their own ``set_uncertainty_module``.
        """
        super().set_uncertainty_module(uncertainty_module)
        for subsamp in self.subsamps:
            subsamp.uncertainty_module = uncertainty_module

    @timer_func
    def query(self, k):
        """Select ``k`` items by delegating weighted shares to the sub-samplers.

        Every sub-sampler except the last gets ``round(k * weight)`` items;
        the last one receives whatever remains so the shares add up to ``k``.
        Sub-samplers whose share is 0 are skipped.

        Side effects: sets ``self.ks`` (per-sub-sampler counts),
        ``self.chosen`` (all picks, in sub-sampler order) and
        ``self.each_chosen`` (picks keyed by sub-sampler instance).

        Returns:
            list: the concatenated picks of all queried sub-samplers.
        """
        # NOTE(review): because of rounding, the remainder handed to the last
        # sub-sampler can differ from round(k * weight) and may even be
        # negative for some weight combinations - confirm callers avoid this.
        self.ks = [round(k * weight) for weight in self.weights[:-1]]
        self.ks.append(k - sum(self.ks))
        self.chosen = []
        self.each_chosen = {}
        trained_DATE_available = False
        total = np.sum(self.ks)  # loop-invariant; only used for the log line

        for subsamp, num_samp in zip(self.subsamps, self.ks):
            if num_samp == 0:
                continue
            print(f'<Hybrid> Querying {num_samp} (={round(100*num_samp/total)}%) items using the {subsamp} subsampler')
            # Hand over what has been picked so far before this sub-sampler runs.
            subsamp.set_available_indices(self.chosen)

            # Query exactly once and reuse the result: calling query() a second
            # time doubles the cost and may return a different selection than
            # the one that was appended to self.chosen.
            if isinstance(subsamp, DATESampling):
                picked = subsamp.query(num_samp, model_available=trained_DATE_available)
                # Later DATE-based sub-samplers are told a model already exists.
                trained_DATE_available = True
            else:
                picked = subsamp.query(num_samp)

            # Build a new list instead of extending in place, so any reference
            # a sub-sampler kept to the previous list is left unchanged.
            self.chosen = [*self.chosen, *picked]
            self.each_chosen[subsamp] = picked

        return self.chosen