-
Notifications
You must be signed in to change notification settings - Fork 0
/
cgnat.py
executable file
·246 lines (211 loc) · 8.22 KB
/
cgnat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#!/usr/bin/env python3
"""
Script: cgnat.py
Author: Viacheslav Hletenko
Date: 2024
Description:
Generate nftables rules for CGNAT
Change external_prefix, internal_prefix and ports_per_user in the main()
Usage: python3 cgnat.py
sudo nft -f cgnat.nft
"""
import ipaddress
def execute_command(command):
    """Proof-of-concept executor: echo *command* to stdout instead of running it.

    Replace the body with a real runner to apply the commands on a host.
    """
    text = str(command)
    print(text)
class NftablesOperations:
    """Collect SNAT rules and render them as an nftables batch file.

    Shell commands are issued through execute_command(), which in this
    proof of concept only prints them.
    """

    def __init__(
        self,
        table_name: str = 'cgnat',
        chain_name: str = 'POSTROUTING',
        hook: str = 'postrouting',
        priority: int = 90,
        interfaces: str = '',
    ):
        """
        Args:
            table_name: name of the ``ip`` family table holding the chain.
            chain_name: name of the NAT chain.
            hook: netfilter hook the chain is attached to.
            priority: chain priority.
            interfaces: optional interface list placed inside an
                ``iifname { ... }`` match (e.g. 'eth0, eth1'); when empty,
                rules are emitted without an interface match.
        """
        self.table_name = table_name
        self.chain_name = chain_name
        self.hook = hook
        self.priority = priority
        self.interfaces = interfaces
        self.rules = []  # rule bodies in insertion order, see add_batch_rule()

    def add_table(self):
        """Issue the command that creates the table."""
        execute_command(f'sudo nft add table ip {self.table_name}')

    def add_chain(self):
        """Issue the command that creates the NAT chain with an accept policy."""
        # '\\;' yields a literal backslash-semicolon so the shell passes ';'
        # through to nft; a bare '\;' is an invalid Python escape sequence.
        execute_command(
            f'sudo nft add chain ip {self.table_name} {self.chain_name} '
            f'{{ type nat hook {self.hook} priority {self.priority} \\; policy accept \\; }}'
        )

    def add_batch_rule(self, rule: str):
        """Queue *rule* for inclusion in the generated batch file."""
        self.rules.append(rule)

    def generate_batch_file(self) -> str:
        """Return the nft batch file text for the queued rules.

        The file (re)creates and flushes the table, then defines the chain
        with every queued rule, each optionally prefixed by an iifname match.
        """
        # Only add the interface match (and its separating space) when
        # interfaces were configured, so rules never get a doubled space.
        prefix = f'iifname {{ {self.interfaces} }} ' if self.interfaces else ''
        lines = [
            '#!/usr/sbin/nft -f',
            '',
            f'add table ip {self.table_name}',
            f'flush table ip {self.table_name}',
            f'table ip {self.table_name} {{',
            f' chain {self.chain_name} {{',
            f' type nat hook {self.hook} priority {self.priority}; policy accept;',
        ]
        lines.extend(f' {prefix}{rule}' for rule in self.rules)
        lines.extend([' }', '}'])
        return '\n'.join(lines) + '\n'

    def apply_rules(self):
        """Write the batch file to a fresh temp file and load it with nft.

        A uniquely named file from tempfile replaces the former fixed
        /tmp/cgnat.nft path, which another local user could pre-create or
        replace with a symlink.
        """
        import shlex
        import tempfile

        with tempfile.NamedTemporaryFile(
            'w', prefix='cgnat-', suffix='.nft', delete=False, encoding='utf-8'
        ) as batch:
            batch.write(self.generate_batch_file())
            path = shlex.quote(batch.name)
        execute_command(f'sudo nft -f {path}')
        execute_command(f'sudo rm {path}')
class IPOperations:
    """Expand an address specification into individual IP addresses.

    Accepted forms (IPv4 or IPv6):
      * CIDR prefix, e.g. '192.0.2.0/30'
      * single address, e.g. '192.0.2.1' (treated as a /32 or /128)
      * inclusive range, e.g. '192.0.2.1-192.0.2.5'
    """

    def __init__(self, ip_prefix: str):
        self.ip_prefix = ip_prefix
        # Ranges ('A-B') are parsed on demand and leave ip_network as None;
        # anything else is parsed as a network. ip_network() is strict, so a
        # prefix with host bits set (e.g. '192.0.2.1/30') raises ValueError.
        self.ip_network = (
            None if '-' in ip_prefix else ipaddress.ip_network(ip_prefix.strip())
        )

    def _range_bounds(self):
        """Return the (first, last) address objects of an 'A-B' range.

        Raises:
            ValueError: if the text is not exactly two addresses joined by
                '-', the ends are different IP versions, or last < first.
        """
        first_text, last_text = self.ip_prefix.split('-')
        first = ipaddress.ip_address(first_text.strip())
        last = ipaddress.ip_address(last_text.strip())
        if first.version != last.version:
            raise ValueError(f'Mixed IP versions in range: {self.ip_prefix}')
        if last < first:
            raise ValueError(f'Range end precedes range start: {self.ip_prefix}')
        return first, last

    def get_ips_count(self) -> int:
        """Return the number of addresses in the prefix or range.

        Network and broadcast addresses are included, so the result equals
        len(convert_prefix_to_list_ips()).

        Example:
            % ip = IPOperations('192.0.2.0/30')
            % ip.get_ips_count()
            4
            % ip = IPOperations('192.0.2.0-192.0.2.2')
            % ip.get_ips_count()
            3
        """
        if self.ip_network is None:
            first, last = self._range_bounds()
            return int(last) - int(first) + 1
        # O(1): no need to enumerate every host of a large prefix.
        return self.ip_network.num_addresses

    def convert_prefix_to_list_ips(self) -> list:
        """Return every address of the prefix or range as strings.

        Network and broadcast addresses are included.

        Example:
            % ip = IPOperations('192.0.2.0/30')
            % ip.convert_prefix_to_list_ips()
            ['192.0.2.0', '192.0.2.1', '192.0.2.2', '192.0.2.3']
            %
            % ip = IPOperations('192.0.2.1-192.0.2.5')
            % ip.convert_prefix_to_list_ips()
            ['192.0.2.1', '192.0.2.2', '192.0.2.3', '192.0.2.4', '192.0.2.5']
        """
        if self.ip_network is None:
            first, last = self._range_bounds()
            # Offset the address object rather than calling ip_address(int):
            # the latter would turn small IPv6 values into IPv4 addresses.
            return [str(first + offset) for offset in range(int(last) - int(first) + 1)]
        # Iterating a network yields all of its addresses in order.
        return [str(address) for address in self.ip_network]
def generate_port_rules(
    external_hosts: list,
    internal_hosts: list,
    port_count: int,
    global_port_range: str = '1024-65535',
) -> list:
    """Generate the list of nftables SNAT rules for the batch file.

    Every internal host gets its own contiguous block of ``port_count``
    ports on one external host. Blocks are handed out in order: the first
    external host is filled from the start of ``global_port_range`` until the
    next block would not fit, then the next external host is used, and so on.

    Three rules are emitted per internal host: a TCP and a UDP rule that
    SNAT into the host's port block, then a rule without a port range.

    Args:
        external_hosts: external addresses to translate to (strings).
        internal_hosts: internal addresses to translate from (strings).
        port_count: number of ports reserved for each internal host.
        global_port_range: inclusive 'START-END' range of usable ports.

    Returns:
        List of rule strings, in internal-host order.

    Raises:
        ValueError: if the port range is malformed or outside 0-65535,
            port_count is below 1, or there are not enough whole port blocks
            for all internal hosts.
    """
    error_message = 'Not enough ports available for the specified parameters'

    start_port, end_port = map(int, global_port_range.split('-'))
    if not 0 <= start_port <= end_port <= 65535:
        raise ValueError(f'Invalid port range: {global_port_range}')
    if port_count < 1:
        raise ValueError(f'Ports per host must be at least 1, got {port_count}')

    # Only whole blocks are usable; a shorter remainder at the top of the
    # range is left unused. Sizing capacity in whole blocks (not raw ports)
    # makes this check exact, and it also rejects port_count larger than the
    # whole range instead of searching external hosts forever.
    blocks_per_external = (end_port - start_port + 1) // port_count
    if len(internal_hosts) > blocks_per_external * len(external_hosts):
        raise ValueError(error_message)

    rules = []
    for position, internal_host in enumerate(internal_hosts):
        # Sequential fill: host #position takes block #block_index on
        # external host #external_index. blocks_per_external >= 1 here,
        # otherwise the capacity check above would already have raised.
        external_index, block_index = divmod(position, blocks_per_external)
        external_host = external_hosts[external_index]
        first_port = start_port + block_index * port_count
        last_port = first_port + port_count - 1
        for protocol in ('tcp', 'udp'):
            rules.append(
                f'meta l4proto {protocol} ip saddr {internal_host} '
                f'counter snat to {external_host}:{first_port}-{last_port}'
            )
        # Traffic not matched above (e.g. ICMP) is translated to the same
        # external address without a port range.
        rules.append(f'ip saddr {internal_host} counter snat to {external_host}')
    return rules
def main():
    """Generate the CGNAT nftables batch file and write it to disk.

    Edit the settings below to match the deployment.

    Returns:
        Process exit status: 0 on success, 1 if the settings are invalid or
        the output file cannot be written (the error goes to stderr).
    """
    import sys

    nft = NftablesOperations()
    nft.add_table()
    nft.add_chain()
    print('---')

    # Change the values to required values
    # external_prefix = "192.0.2.1-192.0.2.5"
    external_prefix: str = '192.0.2.0/30'
    internal_prefix: str = '100.64.0.0/28'
    ports_per_user: int = 8000
    global_port_range: str = '1024-65535'
    output_filename: str = 'cgnat.nft'

    # TODO: per-protocol limits are not implemented; ports_per_user is used
    # for both the TCP and the UDP rules.
    # tcp_ports = 1024
    # udp_ports = 1024
    # icmp_ids = 1024

    # Parse each prefix once; invalid prefixes raise ValueError here.
    try:
        external = IPOperations(external_prefix)
        internal = IPOperations(internal_prefix)
        external_hosts = external.convert_prefix_to_list_ips()
        internal_hosts = internal.convert_prefix_to_list_ips()
        external_count = external.get_ips_count()
        internal_count = internal.get_ips_count()
    except ValueError as e:
        print(e, file=sys.stderr)
        return 1

    print('external hosts count:', external_count)
    print('internal hosts count:', internal_count)
    print('global port range:', global_port_range)
    print('ports per host count:', ports_per_user)
    print('---')

    try:
        rules = generate_port_rules(
            external_hosts, internal_hosts, ports_per_user, global_port_range
        )
    except ValueError as e:
        print(e, file=sys.stderr)
        return 1

    for rule in rules:
        nft.add_batch_rule(rule)

    try:
        with open(output_filename, 'w', encoding='utf-8') as file:
            file.write(nft.generate_batch_file())
    except OSError as e:
        print(f'Cannot write {output_filename}: {e}', file=sys.stderr)
        return 1

    print(f'To apply rules use: nft -f {output_filename}\n')
    return 0


if __name__ == '__main__':
    # Propagate main()'s status so shell callers can detect failures.
    raise SystemExit(main())