-
Notifications
You must be signed in to change notification settings - Fork 4
/
string_model.h
68 lines (54 loc) · 1.84 KB
/
string_model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#ifndef STRING_MODEL_H
#define STRING_MODEL_H
#include "model.h"
#include "base.h"
#include "categorical_model.h"
#include <vector>
namespace db_compress {
// AttrValue implementation whose payload is a single std::string.
class StringAttrValue : public AttrValue {
 public:
  // Builds a value holding an empty string.
  StringAttrValue() : value_() {}

  // Builds a value holding a copy of `val`.
  // NOTE(review): not marked `explicit`, so std::string converts implicitly
  // to StringAttrValue; left as-is because callers may depend on it.
  StringAttrValue(const std::string& val) : value_(val) {}

  // Replaces the stored string with a copy of `val`.
  void Set(const std::string& val) { value_.assign(val); }

  // Returns a read-only reference to the stored string.
  const std::string& Value() const { return value_; }

  // Returns a mutable pointer to the stored string.
  std::string* Pointer() { return &value_; }

 private:
  std::string value_;
};
// SquID for string attributes. It is driven through Init() and the
// *NextBranch() calls, and exposes its StringAttrValue via GetResultAttr().
//
// All non-inline members are defined out of line; only what the declarations
// show is documented here.
class StringSquID : public SquID {
 private:
  // Probability tables handed in through Init(). Stored as raw pointers;
  // presumably borrowed from the owning model -- confirm against Init().
  // Null until Init() is called.
  const std::vector<Prob>* char_prob_ = nullptr;
  const std::vector<Prob>* len_prob_ = nullptr;

  // Backs HasNextBranch(). Defaults to true so that an object that has not
  // been through Init() reports "no further branch" instead of reading an
  // indeterminate bool (and then walking the null tables above).
  bool is_end_ = true;

  // Defaulted to 0 only so the member is never read uninitialized; the
  // working value is assumed to be set by Init() -- TODO confirm.
  int len_ = 0;

  // Value returned by GetResultAttr().
  StringAttrValue attr_;

 public:
  // Binds this object to the character and length probability tables.
  // Expected to be called before any of the branch functions below.
  void Init(const std::vector<Prob>* char_prob, const std::vector<Prob>* len_prob);

  // Returns true while is_end_ is false.
  bool HasNextBranch() const { return !is_end_; }

  void GenerateNextBranch();

  // Returns the branch index corresponding to `attr` (defined out of line).
  int GetNextBranch(const AttrValue* attr) const;

  // Advances using the given branch index (defined out of line).
  void ChooseNextBranch(int branch);

  // Returns a pointer to the internally held attribute; the pointee is a
  // member of this object.
  const AttrValue* GetResultAttr() { return &attr_; }
};
// SquIDModel for string-valued attributes. Holds per-character and per-length
// Prob tables, matching integer count tables, and one StringSquID instance.
//
// Every member function is defined out of line, so the notes below describe
// only what the declarations themselves establish.
class StringModel : public SquIDModel {
 public:
  // NOTE(review): single-argument constructor is not `explicit`; kept that
  // way to leave the existing interface untouched.
  StringModel(size_t target_var);

  // Returns a SquID for `tuple`; presumably the embedded squid_ member --
  // verify against the definition.
  SquID* GetSquID(const Tuple& tuple);

  int GetModelCost() const;

  // Learning interface: FeedTuple() takes one tuple at a time and
  // EndOfData() is the end-of-input hook.
  void FeedTuple(const Tuple& tuple);
  void EndOfData();

  // Serialization interface.
  int GetModelDescriptionLength() const;
  void WriteModel(ByteWriter* byte_writer, size_t block_index) const;
  static SquIDModel* ReadModel(ByteReader* byte_reader, size_t index);

 private:
  // Probability tables.
  std::vector<Prob> char_prob_;
  std::vector<Prob> length_prob_;

  // Integer counters paired by name with the tables above.
  std::vector<int> char_count_;
  std::vector<int> length_count_;

  // SquID instance owned by the model.
  StringSquID squid_;
};
// ModelCreator for string models: exposes one entry point that reads a model
// from a ByteReader and one that creates a model from a schema.
// Both functions are defined out of line.
class StringModelCreator : public ModelCreator {
 public:
  // Reads a model for attribute `index` of `schema` from `byte_reader`.
  SquIDModel* ReadModel(ByteReader* byte_reader,
                        const Schema& schema,
                        size_t index);

  // Creates a model for attribute `index` of `schema`, given the predictor
  // attribute indices in `predictor` and the error parameter `err`.
  SquIDModel* CreateModel(const Schema& schema,
                          const std::vector<size_t>& predictor,
                          size_t index,
                          double err);
};
} // namespace db_compress
#endif