// feedCache is a generic URL-feed cache shared between phishing-feed
// sources (OpenPhish, PhishTank). It holds a hostname-indexed snapshot
// of the feed, refreshes on TTL expiry, and ensures only one refresh is
// in flight at a time so concurrent lookups still serve stale data
// during a refresh.
type feedCache struct {
	// mu guards every field below except fetchFn, which is set once by
	// newFeedCache and only read afterwards.
	mu sync.Mutex

	// urls is the URL list returned by the last successful fetch.
	urls []string

	// byHost indexes feed URLs by host. It is nil until the first
	// successful fetch, which is how lookup detects a cold cache.
	byHost map[string][]string

	// fetchedAt is the time of the last successful fetch.
	fetchedAt time.Time

	// lastAttemptAt is the time the last fetch attempt finished,
	// whether it succeeded or not.
	lastAttemptAt time.Time

	// refreshing is true while one caller is running fetchFn.
	refreshing bool

	// ttl is the snapshot age beyond which the cache is considered stale.
	ttl time.Duration

	// failBackoff is the minimum delay between two fetch attempts.
	failBackoff time.Duration

	// fetchFn downloads the feed and returns the URL list together with
	// its host index.
	fetchFn func(ctx context.Context) (urls []string, byHost map[string][]string, err error)
}

// newFeedCache returns an empty cache that refreshes through fetch.
// A non-positive ttl falls back to one hour; the backoff between fetch
// attempts is one minute.
func newFeedCache(ttl time.Duration, fetch func(context.Context) ([]string, map[string][]string, error)) *feedCache {
	if ttl <= 0 {
		ttl = time.Hour
	}
	return &feedCache{
		ttl:         ttl,
		failBackoff: time.Minute,
		fetchFn:     fetch,
	}
}

// setTTL replaces the staleness threshold. The value is stored as given:
// unlike newFeedCache, no fallback is applied to non-positive durations.
func (c *feedCache) setTTL(d time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.ttl = d
}

// lookup returns every cached feed URL whose host equals domain or is a
// subdomain of it. domain is lower-cased and stripped of one trailing
// dot before matching.
//
// If the snapshot is stale, no refresh is in flight and the backoff has
// elapsed, this call runs the refresh itself before answering; any other
// concurrent caller is answered from the current snapshot.
//
// size is the number of URLs in the snapshot and fetchedAt the time of
// the last successful fetch. err is non-nil only when this call performed
// a refresh that failed; the results then come from the previous
// snapshot, if any. The order of urls is unspecified because it follows
// map iteration order.
func (c *feedCache) lookup(ctx context.Context, domain string) (urls []string, size int, fetchedAt time.Time, err error) {
	domain = strings.ToLower(strings.TrimSuffix(domain, "."))

	if c.beginRefresh() {
		err = c.refresh(ctx)
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	suffix := "." + domain
	for host, hostURLs := range c.byHost {
		if host == domain || strings.HasSuffix(host, suffix) {
			urls = append(urls, hostURLs...)
		}
	}
	return urls, len(c.urls), c.fetchedAt, err
}

// beginRefresh reports whether the caller must refresh the snapshot and,
// if so, marks a refresh as in flight.
func (c *feedCache) beginRefresh() bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	stale := c.byHost == nil || time.Since(c.fetchedAt) > c.ttl
	if !stale || c.refreshing || time.Since(c.lastAttemptAt) <= c.failBackoff {
		return false
	}
	c.refreshing = true
	return true
}

// refresh runs fetchFn without holding the lock and records the outcome.
func (c *feedCache) refresh(ctx context.Context) error {
	var (
		newURLs   []string
		newByHost map[string][]string
		fetched   bool
	)
	// Deferred so the in-flight flag is cleared even if fetchFn panics;
	// otherwise a panic recovered further up the stack would leave
	// refreshing set and block every later refresh.
	defer func() { c.endRefresh(newURLs, newByHost, fetched) }()

	var err error
	newURLs, newByHost, err = c.fetchFn(ctx)
	fetched = err == nil
	return err
}

// endRefresh clears the in-flight flag, stamps the attempt time and, when
// ok is true, installs the new snapshot.
func (c *feedCache) endRefresh(urls []string, byHost map[string][]string, ok bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.refreshing = false
	c.lastAttemptAt = time.Now()
	if !ok {
		return
	}
	if byHost == nil {
		// A nil index means "never loaded" to beginRefresh; store an
		// empty one so a successful fetch is honoured for the full TTL.
		byHost = map[string][]string{}
	}
	c.urls = urls
	c.byHost = byHost
	c.fetchedAt = c.lastAttemptAt
}

// hostOfURL returns the lower-cased host of s without port and without a
// trailing root dot, or "" when s cannot be parsed. Stripping the dot
// mirrors the normalisation lookup applies to the queried domain, so
// "http://example.com./x" is indexed under "example.com".
func hostOfURL(s string) string {
	u, err := url.Parse(s)
	if err != nil {
		return ""
	}
	return strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
}
+func openPhishFetch(feedURL string) func(context.Context) ([]string, map[string][]string, error) { + return func(ctx context.Context) ([]string, map[string][]string, error) { + reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() -type phishCache struct { - mu sync.Mutex - urls []string - byHost map[string][]string - fetchedAt time.Time - lastAttemptAt time.Time - refreshing bool - ttl time.Duration - failBackoff time.Duration - feedURL string -} - -func newPhishCache(feedURL string, ttl time.Duration) *phishCache { - if feedURL == "" { - feedURL = openPhishFeedURL - } - if ttl <= 0 { - ttl = 1 * time.Hour - } - return &phishCache{ttl: ttl, feedURL: feedURL, failBackoff: 1 * time.Minute} -} - -func (c *phishCache) lookup(ctx context.Context, domain string) (urls []string, size int, fetchedAt time.Time, err error) { - domain = strings.ToLower(strings.TrimSuffix(domain, ".")) - - c.mu.Lock() - stale := c.byHost == nil || time.Since(c.fetchedAt) > c.ttl - doRefresh := stale && !c.refreshing && time.Since(c.lastAttemptAt) > c.failBackoff - if doRefresh { - c.refreshing = true - } - c.mu.Unlock() - - if doRefresh { - // Fetch without holding the cache lock so concurrent lookups - // can still serve stale data. Only one refresh runs at a time. - newURLs, newByHost, ferr := c.fetch(ctx) - c.mu.Lock() - c.refreshing = false - c.lastAttemptAt = time.Now() - if ferr == nil { - c.urls = newURLs - c.byHost = newByHost - c.fetchedAt = c.lastAttemptAt - } else { - err = ferr + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, feedURL, nil) + if err != nil { + return nil, nil, err } - c.mu.Unlock() - } + req.Header.Set("User-Agent", "happydomain-checker-blacklist/1.0") - c.mu.Lock() - for host, hostURLs := range c.byHost { - if host == domain || strings.HasSuffix(host, "."+domain) { - urls = append(urls, hostURLs...) 
+ resp, err := sharedHTTPClient.Do(req) + if err != nil { + return nil, nil, err } - } - size = len(c.urls) - fetchedAt = c.fetchedAt - c.mu.Unlock() - return urls, size, fetchedAt, err -} + defer resp.Body.Close() -func (c *phishCache) fetch(ctx context.Context) ([]string, map[string][]string, error) { - reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.feedURL, nil) - if err != nil { - return nil, nil, err - } - req.Header.Set("User-Agent", "happydomain-checker-blacklist/1.0") - - resp, err := sharedHTTPClient.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("openphish HTTP %d", resp.StatusCode) - } - - urls := make([]string, 0, 8192) - byHost := make(map[string][]string, 8192) - scanner := bufio.NewScanner(io.LimitReader(resp.Body, 64<<20)) - scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("openphish HTTP %d", resp.StatusCode) } - urls = append(urls, line) - if h := hostOfURL(line); h != "" { - byHost[h] = append(byHost[h], line) - } - } - if err := scanner.Err(); err != nil { - return nil, nil, err - } - return urls, byHost, nil -} -func hostOfURL(s string) string { - u, err := url.Parse(s) - if err != nil { - return "" + urls := make([]string, 0, 8192) + byHost := make(map[string][]string, 8192) + scanner := bufio.NewScanner(io.LimitReader(resp.Body, 64<<20)) + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + urls = append(urls, line) + if h := hostOfURL(line); h != "" { + byHost[h] = append(byHost[h], line) + } + } + if err := scanner.Err(); 
err != nil { + return nil, nil, err + } + return urls, byHost, nil } - return strings.ToLower(u.Hostname()) } diff --git a/checker/phishtank.go b/checker/phishtank.go index 003fb7b..6208360 100644 --- a/checker/phishtank.go +++ b/checker/phishtank.go @@ -10,7 +10,6 @@ import ( "net/http" "strconv" "strings" - "sync" "time" sdk "git.happydns.org/checker-sdk-go/checker" @@ -21,7 +20,7 @@ const ( phishTankDefaultTTL = 12 * time.Hour ) -var phishTankGlobalCache = newPhishTankCache() +var phishTankGlobalCache = newFeedCache(phishTankDefaultTTL, phishTankFetch) func init() { Register(&phishTankSource{}) } @@ -107,71 +106,8 @@ func (*phishTankSource) Diagnose(res SourceResult) Diagnosis { } } -// ---------- feed cache ---------- - -type phishTankCache struct { - mu sync.Mutex - urls []string - byHost map[string][]string - fetchedAt time.Time - lastAttemptAt time.Time - refreshing bool - ttl time.Duration - failBackoff time.Duration -} - -func newPhishTankCache() *phishTankCache { - return &phishTankCache{ - ttl: phishTankDefaultTTL, - failBackoff: 1 * time.Minute, - } -} - -func (c *phishTankCache) setTTL(d time.Duration) { - c.mu.Lock() - c.ttl = d - c.mu.Unlock() -} - -func (c *phishTankCache) lookup(ctx context.Context, domain string) (urls []string, size int, fetchedAt time.Time, err error) { - domain = strings.ToLower(strings.TrimSuffix(domain, ".")) - - c.mu.Lock() - stale := c.byHost == nil || time.Since(c.fetchedAt) > c.ttl - doRefresh := stale && !c.refreshing && time.Since(c.lastAttemptAt) > c.failBackoff - if doRefresh { - c.refreshing = true - } - c.mu.Unlock() - - if doRefresh { - newURLs, newByHost, ferr := c.fetch(ctx) - c.mu.Lock() - c.refreshing = false - c.lastAttemptAt = time.Now() - if ferr == nil { - c.urls = newURLs - c.byHost = newByHost - c.fetchedAt = c.lastAttemptAt - } else { - err = ferr - } - c.mu.Unlock() - } - - c.mu.Lock() - for host, hostURLs := range c.byHost { - if host == domain || strings.HasSuffix(host, "."+domain) { - urls = 
append(urls, hostURLs...) - } - } - size = len(c.urls) - fetchedAt = c.fetchedAt - c.mu.Unlock() - return urls, size, fetchedAt, err -} - -func (c *phishTankCache) fetch(ctx context.Context) ([]string, map[string][]string, error) { +// phishTankFetch downloads and parses the PhishTank gzip-compressed CSV feed. +func phishTankFetch(ctx context.Context) ([]string, map[string][]string, error) { reqCtx, cancel := context.WithTimeout(ctx, 120*time.Second) defer cancel()