Skip to content

Commit

Permalink
fix/hw3 (#4)
Browse files Browse the repository at this point in the history
* Make log_entry_matches_packet inline

* Rename send_tcp_packet to send_packet

* Fix typo in parser.c

* Fix bug in handling empty rules files and handle comments in rules
  • Loading branch information
itaispiegel authored Feb 4, 2024
1 parent 20b8529 commit 051af58
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 23 deletions.
4 changes: 2 additions & 2 deletions module/netfilter_hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ static inline log_row_t new_log_row_by_packet(packet_t *packet,
};
}

static bool log_entry_matches_packet(struct log_entry *log_entry,
packet_t *packet) {
static inline bool log_entry_matches_packet(struct log_entry *log_entry,
packet_t *packet) {
return log_entry->log_row.protocol == packet->protocol &&
log_entry->log_row.src_ip == packet->src_ip &&
log_entry->log_row.dst_ip == packet->dst_ip &&
Expand Down
2 changes: 1 addition & 1 deletion module/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void parse_packet(packet_t *packet, struct sk_buff *skb) {

packet->protocol = ip_header->protocol;

// Noteice that we the store the exact ports, even if they're above 1023.
// Notice that we the store the exact ports, even if they're above 1023.
if (packet->protocol == PROT_TCP) {
tcp_header = tcp_hdr(skb);
packet->src_port = tcp_header->source;
Expand Down
15 changes: 10 additions & 5 deletions module/rules_table.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ static struct file_operations fops = {
static ssize_t rules_table_show(struct device *dev,
struct device_attribute *attr, char *buf) {
__u8 i;
__u16 offset = 0;
for (i = 0; i < rules_count; i++) {
memcpy(buf + i * sizeof(rule_t), &rules[i], sizeof(rule_t));
memcpy(buf + offset, &rules[i], sizeof(rule_t));
offset += sizeof(rule_t);
}

return rules_count * sizeof(rule_t);
Expand All @@ -24,12 +26,15 @@ static ssize_t rules_table_store(struct device *dev,
struct device_attribute *attr, const char *buf,
size_t count) {

if (count > MAX_RULES * sizeof(rule_t)) {
// Writing a single NULL byte to the table will reset it.
if (count == 1 && *buf == 0) {
memset(rules, 0, sizeof(rules));
rules_count = 0;
return count;
} else if (count > MAX_RULES * sizeof(rule_t)) {
printk(KERN_WARNING "Can't save rules, since the size is too big\n");
return -EINVAL;
}

if (count % sizeof(rule_t) != 0) {
} else if (count % sizeof(rule_t) != 0) {
printk(KERN_WARNING "Can't save rules, since the size is invalid\n");
return -EINVAL;
}
Expand Down
4 changes: 2 additions & 2 deletions scripts/send_transport_layer_pkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@click.option(
"--protocol", prompt=True, type=click.Choice(["tcp", "udp"]), default="tcp"
)
def send_tcp_packet(target_ip: str, target_port: int, source_port: int, protocol: str):
def send_packet(target_ip: str, target_port: int, source_port: int, protocol: str):
transport_layer = None
if protocol == "tcp":
transport_layer = TCP(dport=target_port)
Expand All @@ -37,4 +37,4 @@ def send_tcp_packet(target_ip: str, target_port: int, source_port: int, protocol


if __name__ == "__main__":
send_tcp_packet()
send_packet()
9 changes: 9 additions & 0 deletions user/cmd/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/spf13/cobra"
)

const (
commentPrefix = "#" // Lines starting with this prefix are considered comments
)

var loadCmd = &cobra.Command{
Use: "load_rules [rules_file]",
Short: "Load the firewall rules from a given file",
Expand All @@ -25,6 +29,11 @@ func executeLoadRules(cmd *cobra.Command, args []string) error {
}

rulesLines := utils.SplitLines(strings.TrimSpace(string(rulesBytes)))
rulesLines = utils.RemoveLinesWithPrefix(rulesLines, commentPrefix)
if len(rulesLines) == 0 {
return rulestable.ClearTable()
}

newRules := make([]rules.Rule, len(rulesLines))
for i, ruleLine := range rulesLines {
rule, err := rules.ParseRule(ruleLine)
Expand Down
4 changes: 3 additions & 1 deletion user/pkg/rules/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/itaispiegel/infosec-workshop/user/pkg/fwtypes"
)

var ErrInvalidRuleFormat = errors.New("invalid rule format")

func parseCidr(cidr string) (net.IP, net.IPMask, error) {
if cidr == "any" {
return net.IPv4(0, 0, 0, 0), net.CIDRMask(0, 32), nil
Expand Down Expand Up @@ -38,7 +40,7 @@ func parsePort(port string) (uint16, error) {
func ParseRule(ruleLine string) (*Rule, error) {
fields := strings.Split(ruleLine, " ")
if len(fields) != 9 {
return &Rule{}, errors.New("invalid rule format")
return &Rule{}, ErrInvalidRuleFormat
}

name := fields[0]
Expand Down
31 changes: 20 additions & 11 deletions user/pkg/rulestable/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,12 @@ func SaveRules(rules []rules.Rule) error {
buf = append(buf, rule.Marshal()...)
}

f, err := os.OpenFile(RuleTableDeviceFile, os.O_WRONLY, 0)
if err != nil {
return err
}
defer f.Close()

_, err = f.Write(buf)
if err != nil {
return err
}
return writeToRulesTable(buf)
}

return nil
// Clears the firewall rule table.
func ClearTable() error {
return writeToRulesTable([]byte{0})
}

// Reads the rule table from the firewall rule table device file, and returns it.
Expand All @@ -47,3 +41,18 @@ func ReadRules() ([]rules.Rule, error) {

return table, nil
}

func writeToRulesTable(buf []byte) error {
f, err := os.OpenFile(RuleTableDeviceFile, os.O_WRONLY, 0)
if err != nil {
return err
}
defer f.Close()

_, err = f.Write(buf)
if err != nil {
return err
}

return nil
}
20 changes: 19 additions & 1 deletion user/pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,27 @@ const (
LF = "\n"
)

// Splits a string into lines, and returns a slice of the lines.
// The string can be in either CRLF or LF line endings.
// If the string is empty, an empty slice is returned.
func SplitLines(text string) []string {
normalized := strings.Replace(text, CRLF, LF, -1)
return strings.Split(normalized, LF)
if res := strings.Split(normalized, LF); len(res) == 1 && res[0] == "" {
return []string{}
} else {
return res
}
}

// Removes lines from a slice of lines that start with a given prefix.
func RemoveLinesWithPrefix(lines []string, prefix string) []string {
var res []string
for _, line := range lines {
if !strings.HasPrefix(line, prefix) {
res = append(res, line)
}
}
return res
}

func PanicIfError(err error) {
Expand Down

0 comments on commit 051af58

Please sign in to comment.