checker-ssh/checker/collect.go

218 lines
6.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// This file is part of the happyDomain (R) project.
// Copyright (c) 2020-2026 happyDomain
// Authors: Pierre-Olivier Mercier, et al.
//
// This program is offered under a commercial and under the AGPL license.
// For commercial licensing, contact us at <contact@happydomain.org>.
//
// For AGPL licensing:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package checker
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
sdk "git.happydns.org/checker-sdk-go/checker"
happydns "git.happydns.org/happyDomain/model"
"git.happydns.org/happyDomain/services/abstract"
)
// Collect gathers the raw SSH observations for one check run.
//
// It decodes the abstract.Server service carried in opts, takes the probe
// targets from its A/AAAA records and its SSHFP records, then calls
// probeEndpoint once per (IP, port) pair. Probes run in their own
// goroutines, with at most MaxConcurrentProbes in flight at a time.
//
// Options read from opts:
//   - OptionProbeTimeoutMs: per-probe timeout in milliseconds; values <= 0
//     fall back to DefaultProbeTimeoutMs.
//   - OptionIncludeAuthProbe: forwarded to probeEndpoint, defaults to true.
//   - OptionPorts: comma-separated extra ports; DefaultSSHPort is always
//     probed and is put first when the option does not already list it.
//
// The returned value is a *SSHData. Its Endpoints are appended as probes
// finish, so their order is not guaranteed to follow the IP/port order.
// An error is returned when the service cannot be resolved or when it
// carries no A/AAAA record.
func (p *sshProvider) Collect(ctx context.Context, opts sdk.CheckerOptions) (any, error) {
	srv, err := resolveServer(opts)
	if err != nil {
		return nil, err
	}

	probeTimeoutMs := sdk.GetIntOption(opts, OptionProbeTimeoutMs, DefaultProbeTimeoutMs)
	if probeTimeoutMs <= 0 {
		probeTimeoutMs = DefaultProbeTimeoutMs
	}
	probeTimeout := time.Duration(probeTimeoutMs) * time.Millisecond

	withAuthProbe := sdk.GetBoolOption(opts, OptionIncludeAuthProbe, true)

	// The default SSH port goes first unless the caller already listed it.
	portList := parsePorts(optString(opts, OptionPorts, ""))
	if !containsUint16(portList, DefaultSSHPort) {
		portList = append([]uint16{DefaultSSHPort}, portList...)
	}

	domain, addrs := addressesFromServer(srv)
	if len(addrs) == 0 {
		return nil, fmt.Errorf("abstract.Server service has no A/AAAA records")
	}

	fingerprints := sshfpFromServer(srv)
	result := &SSHData{
		Domain:      domain,
		SSHFP:       fingerprints,
		CollectedAt: time.Now(),
	}

	// slots is a counting semaphore: the launcher blocks on the send once
	// MaxConcurrentProbes probes are running, and each probe frees its
	// slot when it returns. resultLock guards the shared Endpoints slice.
	var (
		resultLock sync.Mutex
		pending    sync.WaitGroup
	)
	slots := make(chan struct{}, MaxConcurrentProbes)

	for _, addr := range addrs {
		for _, portNum := range portList {
			pending.Add(1)
			slots <- struct{}{}
			go func(addr string, portNum uint16) {
				// Free the slot first, then signal completion.
				defer func() {
					<-slots
					pending.Done()
				}()

				endpoint := probeEndpoint(ctx, domain, addr, portNum, probeTimeout, withAuthProbe, fingerprints)
				log.Printf("checker-ssh: %s:%d banner=%q kex=%d hostkeys=%d stage=%s",
					addr, portNum, endpoint.Banner, len(endpoint.KEX), len(endpoint.HostKeys), endpoint.Stage)

				resultLock.Lock()
				result.Endpoints = append(result.Endpoints, endpoint)
				resultLock.Unlock()
			}(addr, portNum)
		}
	}
	pending.Wait()

	return result, nil
}
// resolveServer pulls the abstract.Server payload out of the checker
// options.
//
// The OptionService entry is read as a happydns.ServiceMessage through
// sdk.GetOption. Its Type must be exactly "abstract.Server"; the Service
// field is then JSON-decoded into an abstract.Server.
//
// NOTE(review): the original comment said sdk.GetOption covers both the
// in-process plugin path and values that arrived over HTTP (by JSON
// round-tripping), as in the ping checker. That is not verifiable from
// this file; confirm against the SDK.
//
// An error is returned when the option is absent, when the service has
// another type, or when the payload does not decode.
func resolveServer(opts sdk.CheckerOptions) (*abstract.Server, error) {
	msg, found := sdk.GetOption[happydns.ServiceMessage](opts, OptionService)
	switch {
	case !found:
		return nil, fmt.Errorf("no service in options: did the host wire AutoFillService?")
	case msg.Type != "abstract.Server":
		return nil, fmt.Errorf("service is %q, expected abstract.Server", msg.Type)
	}

	decoded := new(abstract.Server)
	if err := json.Unmarshal(msg.Service, decoded); err != nil {
		return nil, fmt.Errorf("unmarshal abstract.Server: %w", err)
	}
	return decoded, nil
}
// addressesFromServer lists the IP addresses to probe and picks a host
// name to label them with.
//
// ips holds the textual form of the A record's address followed by the
// AAAA record's address, for whichever of the two is present and
// non-empty. host is the owner name of the A record with its trailing dot
// removed; when that yields an empty string, the AAAA record's owner name
// is used instead. Both results are zero-valued when the server has
// neither record.
//
// NOTE(review): per the original author's note, the Server payload does
// not identify the service's owner domain on its own, and host is meant
// to be informational for the report rather than required by the SSH
// handshake. Confirm against probeEndpoint.
func addressesFromServer(server *abstract.Server) (host string, ips []string) {
	if rr := server.A; rr != nil && len(rr.A) > 0 {
		host = strings.TrimSuffix(rr.Hdr.Name, ".")
		ips = append(ips, rr.A.String())
	}

	if rr := server.AAAA; rr != nil && len(rr.AAAA) > 0 {
		ips = append(ips, rr.AAAA.String())
		// Only fall back to the AAAA owner name when A gave us nothing.
		if host == "" {
			host = strings.TrimSuffix(rr.Hdr.Name, ".")
		}
	}

	return host, ips
}
// sshfpFromServer converts the SSHFP records attached to the service into
// an SSHFPSummary.
//
// Present is true as soon as server.SSHFP has at least one entry, nil
// entries included. Records gets one SSHFPRecord per non-nil entry, with
// Algorithm and Type copied as-is and the fingerprint lower-cased.
func sshfpFromServer(server *abstract.Server) SSHFPSummary {
	var summary SSHFPSummary
	summary.Present = len(server.SSHFP) > 0

	for i := range server.SSHFP {
		if rec := server.SSHFP[i]; rec != nil {
			summary.Records = append(summary.Records, SSHFPRecord{
				Algorithm:   rec.Algorithm,
				Type:        rec.Type,
				Fingerprint: strings.ToLower(rec.FingerPrint),
			})
		}
	}

	return summary
}
// parsePorts turns a comma-separated list of port numbers into a slice of
// unique ports, kept in order of first appearance.
//
// Each entry is trimmed of surrounding whitespace. Blank entries, entries
// that are not integers and values outside 1-65535 are dropped without
// reporting an error, so that a single bad token in user input does not
// fail the whole check. The result is nil when raw is empty or contains
// no valid port.
func parsePorts(raw string) []uint16 {
	if len(raw) == 0 {
		return nil
	}

	var ports []uint16
	for _, field := range strings.Split(raw, ",") {
		// Atoi also rejects the empty string, which covers blank entries.
		value, err := strconv.Atoi(strings.TrimSpace(field))
		if err != nil {
			continue
		}
		if value < 1 || value > 65535 {
			continue
		}
		if port := uint16(value); !containsUint16(ports, port) {
			ports = append(ports, port)
		}
	}
	return ports
}
// containsUint16 reports whether v occurs in list. A nil or empty list
// never contains anything.
func containsUint16(list []uint16, v uint16) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == v {
			return true
		}
	}
	return false
}
// optString reads opts[key] as a string.
//
// It returns def when the key is absent or when the stored value has a
// type this function does not know how to render. Accepted value types:
//   - string: returned unchanged.
//   - fmt.Stringer (json.Number among others): the result of String().
//   - float64: formatted in plain decimal with the shortest exact
//     representation, so 2222 becomes "2222". This is the type
//     encoding/json produces for a bare JSON number decoded into an
//     interface value. float64 is not a fmt.Stringer, so without this
//     case such a value would silently fall back to def.
func optString(opts sdk.CheckerOptions, key, def string) string {
	v, ok := opts[key]
	if !ok {
		return def
	}
	switch s := v.(type) {
	case string:
		return s
	case fmt.Stringer:
		return s.String()
	case float64:
		// 'f' with precision -1 avoids exponent notation and trailing zeros.
		return strconv.FormatFloat(s, 'f', -1, 64)
	}
	return def
}
// Blank-identifier references that keep the miekg/dns and net imports in
// use. The Go compiler rejects a file that imports a package without
// referring to it, and no other code visible in this file names either
// package directly, so these guards must stay for as long as the imports
// do.
var (
	_ = dns.TypeSSHFP
	_ = net.IPv4len
)