diff --git a/module/netfilter_hook.c b/module/netfilter_hook.c index 420e1fb..25640d0 100644 --- a/module/netfilter_hook.c +++ b/module/netfilter_hook.c @@ -95,8 +95,8 @@ static inline log_row_t new_log_row_by_packet(packet_t *packet, }; } -static bool log_entry_matches_packet(struct log_entry *log_entry, - packet_t *packet) { +static inline bool log_entry_matches_packet(struct log_entry *log_entry, + packet_t *packet) { return log_entry->log_row.protocol == packet->protocol && log_entry->log_row.src_ip == packet->src_ip && log_entry->log_row.dst_ip == packet->dst_ip && diff --git a/module/parser.c b/module/parser.c index cf80cee..ce2a6e6 100644 --- a/module/parser.c +++ b/module/parser.c @@ -25,7 +25,7 @@ void parse_packet(packet_t *packet, struct sk_buff *skb) { packet->protocol = ip_header->protocol; - // Noteice that we the store the exact ports, even if they're above 1023. + // Notice that we the store the exact ports, even if they're above 1023. if (packet->protocol == PROT_TCP) { tcp_header = tcp_hdr(skb); packet->src_port = tcp_header->source; diff --git a/module/rules_table.c b/module/rules_table.c index ea41226..eb31942 100644 --- a/module/rules_table.c +++ b/module/rules_table.c @@ -13,8 +13,10 @@ static struct file_operations fops = { static ssize_t rules_table_show(struct device *dev, struct device_attribute *attr, char *buf) { __u8 i; + __u16 offset = 0; for (i = 0; i < rules_count; i++) { - memcpy(buf + i * sizeof(rule_t), &rules[i], sizeof(rule_t)); + memcpy(buf + offset, &rules[i], sizeof(rule_t)); + offset += sizeof(rule_t); } return rules_count * sizeof(rule_t); @@ -24,12 +26,15 @@ static ssize_t rules_table_store(struct device *dev, struct device_attribute *attr, const char *buf, size_t count) { - if (count > MAX_RULES * sizeof(rule_t)) { + // Writing a single NULL byte to the table will reset it. + if (count == 1 && *buf == 0) { + memset(rules, 0, sizeof(rules)); + rules_count = 0; + return count; + } else if (count > MAX_RULES * sizeof(rule_t)) { printk(KERN_WARNING "Can't save rules, since the size is too big\n"); return -EINVAL; - } - - if (count % sizeof(rule_t) != 0) { + } else if (count % sizeof(rule_t) != 0) { printk(KERN_WARNING "Can't save rules, since the size is invalid\n"); return -EINVAL; } diff --git a/scripts/send_transport_layer_pkt.py b/scripts/send_transport_layer_pkt.py index ec681f5..a57f4b7 100644 --- a/scripts/send_transport_layer_pkt.py +++ b/scripts/send_transport_layer_pkt.py @@ -19,7 +19,7 @@ @click.option( "--protocol", prompt=True, type=click.Choice(["tcp", "udp"]), default="tcp" ) -def send_tcp_packet(target_ip: str, target_port: int, source_port: int, protocol: str): +def send_packet(target_ip: str, target_port: int, source_port: int, protocol: str): transport_layer = None if protocol == "tcp": transport_layer = TCP(dport=target_port) @@ -37,4 +37,4 @@ def send_tcp_packet(target_ip: str, target_port: int, source_port: int, protocol if __name__ == "__main__": - send_tcp_packet() + send_packet() diff --git a/user/cmd/load.go b/user/cmd/load.go index 7eab340..83d4670 100644 --- a/user/cmd/load.go +++ b/user/cmd/load.go @@ -10,6 +10,10 @@ import ( "github.com/spf13/cobra" ) +const ( + commentPrefix = "#" // Lines starting with this prefix are considered comments +) + var loadCmd = &cobra.Command{ Use: "load_rules [rules_file]", Short: "Load the firewall rules from a given file", @@ -25,6 +29,11 @@ func executeLoadRules(cmd *cobra.Command, args []string) error { } rulesLines := utils.SplitLines(strings.TrimSpace(string(rulesBytes))) + rulesLines = utils.RemoveLinesWithPrefix(rulesLines, commentPrefix) + if len(rulesLines) == 0 { + return rulestable.ClearTable() + } + newRules := make([]rules.Rule, len(rulesLines)) for i, ruleLine := range rulesLines { rule, err := rules.ParseRule(ruleLine) diff --git a/user/pkg/rules/parse.go b/user/pkg/rules/parse.go index 2bbc0e7..f5a6b5a 100644 --- a/user/pkg/rules/parse.go +++ b/user/pkg/rules/parse.go @@ -9,6 +9,8 @@ import ( "github.com/itaispiegel/infosec-workshop/user/pkg/fwtypes" ) +var ErrInvalidRuleFormat = errors.New("invalid rule format") + func parseCidr(cidr string) (net.IP, net.IPMask, error) { if cidr == "any" { return net.IPv4(0, 0, 0, 0), net.CIDRMask(0, 32), nil @@ -38,7 +40,7 @@ func parsePort(port string) (uint16, error) { func ParseRule(ruleLine string) (*Rule, error) { fields := strings.Split(ruleLine, " ") if len(fields) != 9 { - return &Rule{}, errors.New("invalid rule format") + return &Rule{}, ErrInvalidRuleFormat } name := fields[0] diff --git a/user/pkg/rulestable/table.go b/user/pkg/rulestable/table.go index dec7643..7c1bd07 100644 --- a/user/pkg/rulestable/table.go +++ b/user/pkg/rulestable/table.go @@ -19,18 +19,12 @@ func SaveRules(rules []rules.Rule) error { buf = append(buf, rule.Marshal()...) } - f, err := os.OpenFile(RuleTableDeviceFile, os.O_WRONLY, 0) - if err != nil { - return err - } - defer f.Close() - - _, err = f.Write(buf) - if err != nil { - return err - } + return writeToRulesTable(buf) +} - return nil +// Clears the firewall rule table. +func ClearTable() error { + return writeToRulesTable([]byte{0}) } // Reads the rule table from the firewall rule table device file, and returns it. @@ -47,3 +41,18 @@ func ReadRules() ([]rules.Rule, error) { return table, nil } + +func writeToRulesTable(buf []byte) error { + f, err := os.OpenFile(RuleTableDeviceFile, os.O_WRONLY, 0) + if err != nil { + return err + } + defer f.Close() + + _, err = f.Write(buf) + if err != nil { + return err + } + + return nil +} diff --git a/user/pkg/utils/utils.go b/user/pkg/utils/utils.go index dcda9ca..7f54bf7 100644 --- a/user/pkg/utils/utils.go +++ b/user/pkg/utils/utils.go @@ -9,9 +9,27 @@ const ( LF = "\n" ) +// Splits a string into lines, and returns a slice of the lines. +// The string can be in either CRLF or LF line endings. +// If the string is empty, an empty slice is returned. func SplitLines(text string) []string { normalized := strings.Replace(text, CRLF, LF, -1) - return strings.Split(normalized, LF) + if res := strings.Split(normalized, LF); len(res) == 1 && res[0] == "" { + return []string{} + } else { + return res + } +} + +// Removes lines from a slice of lines that start with a given prefix. +func RemoveLinesWithPrefix(lines []string, prefix string) []string { + var res []string + for _, line := range lines { + if !strings.HasPrefix(line, prefix) { + res = append(res, line) + } + } + return res } func PanicIfError(err error) {