diff --git a/validator/login.go b/validator/login.go index 1ae9a67..a5339f0 100644 --- a/validator/login.go +++ b/validator/login.go @@ -1,13 +1,30 @@ package main import ( + "crypto/tls" "encoding/json" + "errors" + "fmt" + "io/ioutil" "log" + "net" "net/http" + "os" + "path" + "strings" + "text/template" + + "gopkg.in/ldap.v2" ) -type loginChecker struct{ - students []Student +type loginChecker struct { + students []Student + ldapAddr string + ldapPort int + ldapIsTLS bool + ldapBase string + ldapBindUsername string + ldapBindPassword string } type loginUpload struct { @@ -15,6 +32,65 @@ type loginUpload struct { Password string } +func (l loginChecker) ldapAuth(username, password string) (res bool, err error) { + tlsCnf := tls.Config{InsecureSkipVerify: true} + + var c *ldap.Conn + + if l.ldapIsTLS { + c, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", l.ldapAddr, l.ldapPort), &tlsCnf) + if err != nil { + return false, err + } + } else { + c, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", l.ldapAddr, l.ldapPort)) + if err != nil { + return false, err + } + + // Reconnect with TLS + err = c.StartTLS(&tlsCnf) + if err != nil { + return false, err + } + } + defer c.Close() + + if l.ldapBindUsername != "" { + err = c.Bind(l.ldapBindUsername, l.ldapBindPassword) + if err != nil { + return false, err + } + } + + // Search for the given username + searchRequest := ldap.NewSearchRequest( + l.ldapBase, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(&(objectClass=person)(uid=%s))", username), + []string{"dn"}, + nil, + ) + + sr, err := c.Search(searchRequest) + if err != nil { + return false, err + } + + if len(sr.Entries) != 1 { + return false, errors.New("User does not exist or too many entries returned") + } + + userdn := sr.Entries[0].DN + + err = c.Bind(userdn, password) + if err != nil { + return false, err + } + + return true, nil +} + func (l loginChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { if addr := r.Header.Get("X-Forwarded-For"); addr != "" { r.RemoteAddr = addr @@ -59,6 +135,16 @@ func (l loginChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if ok, err := l.ldapAuth(lu.Username, lu.Password); err != nil { + log.Println("Unable to perform authentication for", lu.Username, ":", err, "at", r.RemoteAddr) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } else if !ok { + log.Println("Login failed:", lu.Username, "at", r.RemoteAddr) + http.Error(w, "Invalid password", http.StatusUnauthorized) + return + } + if err := l.lateLoginAction(lu.Username, r.RemoteAddr); err != nil { log.Println("Error on late login action:", err) http.Error(w, "Internal server error. Please retry in a few minutes", http.StatusInternalServerError) diff --git a/validator/main.go b/validator/main.go index 44d632c..347b22c 100644 --- a/validator/main.go +++ b/validator/main.go @@ -12,10 +12,19 @@ var tftpDir string func main() { var studentsFile string + var lc loginChecker + var bind = flag.String("bind", ":8081", "Bind port/socket") flag.StringVar(&studentsFile, "students", "./students.csv", "Path to a CSV file containing students list") flag.StringVar(&ARPTable, "arp", ARPTable, "Path to ARP table") flag.StringVar(&tftpDir, "tftpdir", "/var/tftp/", "Path to TFTPd directory") + + flag.StringVar(&lc.ldapAddr, "ldaphost", "auth.cri.epita.fr", "LDAP host") + flag.IntVar(&lc.ldapPort, "ldapport", 636, "LDAP port") + flag.BoolVar(&lc.ldapIsTLS, "ldaptls", false, "Is LDAP connection LDAPS?") + flag.StringVar(&lc.ldapBase, "ldapbase", "dc=epita,dc=net", "LDAP base") + flag.StringVar(&lc.ldapBindUsername, "ldapbindusername", "", "LDAP user to use in order to perform bind (optional if search can be made anonymously)") + flag.StringVar(&lc.ldapBindPassword, "ldapbindpassword", "", "Password for the bind user") flag.Parse() var err error @@ -26,8 +35,7 @@ func main() { log.Fatal(err) } - var students []Student - students, err = readStudentsList(studentsFile) + lc.students, err = readStudentsList(studentsFile) if err != nil { log.Fatal(err) } @@ -35,7 +43,7 @@ func main() { log.Println("Registering handlers...") mux := http.NewServeMux() mux.HandleFunc("/", Index) - mux.Handle("/login", loginChecker{students}) + mux.Handle("/login", lc) http.HandleFunc("/", mux.ServeHTTP) log.Println("Ready, listening on port", *bind)