From 3248916e00640204de0caf5167a93f1f9d05df65 Mon Sep 17 00:00:00 2001 From: caffix Date: Thu, 5 Dec 2024 13:37:05 -0500 Subject: [PATCH] moved the performance improvement to the related support routine --- engine/plugins/ip_netblock.go | 100 ++---------------------------- engine/plugins/support/support.go | 86 ++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 96 deletions(-) diff --git a/engine/plugins/ip_netblock.go b/engine/plugins/ip_netblock.go index bbbcb8c5..557c6b75 100644 --- a/engine/plugins/ip_netblock.go +++ b/engine/plugins/ip_netblock.go @@ -9,7 +9,6 @@ import ( "fmt" "log/slog" "net/netip" - "sync" "time" "github.com/owasp-amass/amass/v4/engine/plugins/support" @@ -22,28 +21,13 @@ import ( "github.com/owasp-amass/open-asset-model/relation" ) -type sessnets struct { - last time.Time - nets map[string]*oamnet.Netblock -} - type ipNetblock struct { - name string - log *slog.Logger - done chan struct{} - mlock sync.Mutex - netblocks map[string]*sessnets + name string + log *slog.Logger } func NewIPNetblock() et.Plugin { - p := &ipNetblock{ - name: "IP-Netblock", - done: make(chan struct{}, 2), - netblocks: make(map[string]*sessnets), - } - - go p.checkNetblocks() - return p + return &ipNetblock{name: "IP-Netblock"} } func (d *ipNetblock) Name() string { @@ -71,7 +55,6 @@ func (d *ipNetblock) Start(r et.Registry) error { } func (d *ipNetblock) Stop() { - close(d.done) d.log.Info("Plugin stopped") } @@ -100,13 +83,9 @@ func (d *ipNetblock) lookup(e *et.Event) error { } else { var err error - netblock, err = d.lookupNetblock(e.Session.ID().String(), ip) + netblock, err = support.IPToNetblockWithAttempts(e.Session, ip, 60, time.Second) if err != nil { - netblock, err = support.IPToNetblockWithAttempts(e.Session, ip, 60, time.Second) - if err != nil { - return nil - } - d.addNetblock(e.Session.ID().String(), netblock) + return nil } } @@ -152,72 +131,3 @@ func (d *ipNetblock) reservedAS(e *et.Event, netblock *oamnet.Netblock) { } } } - -func (d *ipNetblock) lookupNetblock(sessid string, ip *oamnet.IPAddress) (*oamnet.Netblock, error) { - d.mlock.Lock() - defer d.mlock.Unlock() - - n, ok := d.netblocks[sessid] - if !ok { - return nil, errors.New("no netblocks found") - } - n.last = time.Now() - - var size int - var found *oamnet.Netblock - for _, nb := range n.nets { - if nb.CIDR.Contains(ip.Address) { - if s := nb.CIDR.Masked().Bits(); s > size { - size = s - found = nb - } - } - } - - if found == nil { - return nil, errors.New("no netblock match") - } - return found, nil -} - -func (d *ipNetblock) addNetblock(sessid string, nb *oamnet.Netblock) { - d.mlock.Lock() - defer d.mlock.Unlock() - - if _, found := d.netblocks[sessid]; !found { - d.netblocks[sessid] = &sessnets{nets: make(map[string]*oamnet.Netblock)} - } - - d.netblocks[sessid].last = time.Now() - d.netblocks[sessid].nets[nb.CIDR.String()] = nb -} - -func (d *ipNetblock) checkNetblocks() { - t := time.NewTicker(10 * time.Minute) - defer t.Stop() - - for { - select { - case <-d.done: - return - case <-t.C: - d.cleanSessionNetblocks() - } - } -} - -func (d *ipNetblock) cleanSessionNetblocks() { - d.mlock.Lock() - defer d.mlock.Unlock() - - var sessids []string - for sessid, n := range d.netblocks { - if time.Since(n.last) > time.Hour { - sessids = append(sessids, sessid) - } - } - - for _, sessid := range sessids { - delete(d.netblocks, sessid) - } -} diff --git a/engine/plugins/support/support.go b/engine/plugins/support/support.go index 21517e4d..40406b04 100644 --- a/engine/plugins/support/support.go +++ b/engine/plugins/support/support.go @@ -12,6 +12,7 @@ import ( "os" "regexp" "strings" + "sync" "time" "github.com/caffix/stringset" @@ -28,12 +29,19 @@ import ( xurls "mvdan.cc/xurls/v2" ) +type sessnets struct { + last time.Time + nets map[string]*oamnet.Netblock +} + type SweepCallback func(d *et.Event, addr *oamnet.IPAddress, src *et.Source) const MaxHandlerInstances int = 100 var done chan struct{} var subre, urlre *regexp.Regexp +var mlock sync.Mutex +var netblocks map[string]*sessnets func init() { done = make(chan struct{}) @@ -43,6 +51,8 @@ func init() { urlre = xurls.Relaxed() subre = regexp.MustCompile(dns.AnySubdomainRegexString()) + netblocks = make(map[string]*sessnets) + go checkNetblocks() postalHost = os.Getenv("POSTAL_SERVER_HOST") postalPort = os.Getenv("POSTAL_SERVER_PORT") @@ -143,9 +153,12 @@ func IPToNetblockWithAttempts(session et.Session, ip *oamnet.IPAddress, num int, } func IPToNetblock(session et.Session, ip *oamnet.IPAddress) (*oamnet.Netblock, error) { + if nb, err := lookupNetblock(session.ID().String(), ip); err == nil { + return nb, nil + } + var size int var found *oamnet.Netblock - if entities, err := session.Cache().FindEntitiesByType(oam.Netblock, session.Cache().StartTime()); err == nil && len(entities) > 0 { for _, entity := range entities { if nb, ok := entity.Asset.(*oamnet.Netblock); ok && nb.CIDR.Contains(ip.Address) { @@ -160,9 +173,80 @@ func IPToNetblock(session et.Session, ip *oamnet.IPAddress) (*oamnet.Netblock, e if found == nil { return nil, errors.New("no netblock match in the cache") } + + addNetblock(session.ID().String(), found) return found, nil } +func lookupNetblock(sessid string, ip *oamnet.IPAddress) (*oamnet.Netblock, error) { + mlock.Lock() + defer mlock.Unlock() + + n, ok := netblocks[sessid] + if !ok { + return nil, errors.New("no netblocks found") + } + n.last = time.Now() + + var size int + var found *oamnet.Netblock + for _, nb := range n.nets { + if nb.CIDR.Contains(ip.Address) { + if s := nb.CIDR.Masked().Bits(); s > size { + size = s + found = nb + } + } + } + + if found == nil { + return nil, errors.New("no netblock match") + } + return found, nil +} + +func addNetblock(sessid string, nb *oamnet.Netblock) { + mlock.Lock() + defer mlock.Unlock() + + if _, found := netblocks[sessid]; !found { + netblocks[sessid] = &sessnets{nets: make(map[string]*oamnet.Netblock)} + } + + netblocks[sessid].last = time.Now() + netblocks[sessid].nets[nb.CIDR.String()] = nb +} + +func checkNetblocks() { + t := time.NewTicker(10 * time.Minute) + defer t.Stop() + + for { + select { + case <-done: + return + case <-t.C: + cleanSessionNetblocks() + } + } +} + +func cleanSessionNetblocks() { + mlock.Lock() + defer mlock.Unlock() + + var sessids []string + for sessid, n := range netblocks { + if time.Since(n.last) > time.Hour { + sessids = append(sessids, sessid) + } + } + + for _, sessid := range sessids { + delete(netblocks, sessid) + } +} + func IPAddressSweep(e *et.Event, addr *oamnet.IPAddress, src *et.Source, size int, callback SweepCallback) { // do not work on an IP address that was processed previously _, err := e.Session.Cache().FindEntityByContent(addr, e.Session.Cache().StartTime())