token-validator: fix ssh part and add support for ssh-piperd

This commit is contained in:
nemunaire 2019-02-26 16:48:55 +01:00
parent 20749da348
commit ff9c6bacdf
2 changed files with 79 additions and 5 deletions

View File

@ -58,7 +58,8 @@ func main() {
var dsn = flag.String("dsn", DSNGenerator(), "DSN to connect to the MySQL server") var dsn = flag.String("dsn", DSNGenerator(), "DSN to connect to the MySQL server")
var baseURL = flag.String("baseurl", "/", "URL prepended to each URL") var baseURL = flag.String("baseurl", "/", "URL prepended to each URL")
flag.StringVar(&sharedSecret, "sharedsecret", "adelina", "secret used to communicate with remote validator") flag.StringVar(&sharedSecret, "sharedsecret", "adelina", "secret used to communicate with remote validator")
flag.StringVar(&AuthorizedKeyLocation, "authorizedkeyslocation", Authorizedkeyslocation, "File for allowing user to SSH to the machine") flag.StringVar(&AuthorizedKeysLocation, "authorizedkeyslocation", AuthorizedKeysLocation, "File for allowing user to SSH to the machine")
flag.StringVar(&SshPiperLocation, "sshPiperLocation", SshPiperLocation, "Directory containing directories for sshpiperd")
flag.Parse() flag.Parse()
// Sanitize options // Sanitize options

View File

@ -10,12 +10,15 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"path"
"strconv"
"time" "time"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
) )
var AuthorizedKeyLocation = "/var/lib/adlin/.ssh/authorized_keys" var AuthorizedKeysLocation = "/var/lib/adlin/.ssh/authorized_keys"
var SshPiperLocation = "/var/sshpiper/"
func init() { func init() {
router.GET("/sshkeys/", apiHandler( router.GET("/sshkeys/", apiHandler(
@ -23,9 +26,23 @@ func init() {
return getStudentKeys() return getStudentKeys()
})) }))
router.POST("/sshkeys/", rawHandler(receiveKey)) router.POST("/sshkeys/", rawHandler(receiveKey))
router.GET("/sshkeys/authorizedkey", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { router.GET("/sshkeys/authorizedkeys", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
dumpAuthorizedKeysFile(w) dumpAuthorizedKeysFile(w)
}) })
router.GET("/api/students/:sid/authorizedkeys", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if sid, err := strconv.Atoi(string(ps.ByName("sid"))); err != nil {
if student, err := getStudentByLogin(ps.ByName("sid")); err != nil {
http.Error(w, "Student doesn't exist.", http.StatusNotFound)
} else {
student.dumpAuthorizedKeysFile(w)
}
} else if student, err := getStudent(sid); err != nil {
http.Error(w, "Student doesn't exist.", http.StatusNotFound)
} else {
student.dumpAuthorizedKeysFile(w)
}
})
} }
type StudentKey struct { type StudentKey struct {
@ -56,6 +73,27 @@ func getStudentKeys() (keys []StudentKey, err error) {
} }
} }
func (s Student) getKeys() (keys []StudentKey, err error) {
if rows, errr := DBQuery("SELECT id_key, id_student, sshkey, time FROM student_keys WHERE id_student = ?", s.Id); errr != nil {
return nil, errr
} else {
defer rows.Close()
for rows.Next() {
var k StudentKey
if err = rows.Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time); err != nil {
return
}
keys = append(keys, k)
}
if err = rows.Err(); err != nil {
return
}
return
}
}
func getStudentKey(id int) (k StudentKey, err error) { func getStudentKey(id int) (k StudentKey, err error) {
err = DBQueryRow("SELECT id_key, id_student, sshkey, time FROM student_keys WHERE id_key=?", id).Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time) err = DBQueryRow("SELECT id_key, id_student, sshkey, time FROM student_keys WHERE id_key=?", id).Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time)
return return
@ -130,16 +168,34 @@ func receiveKey(r *http.Request, ps httprouter.Params, body []byte) (interface{}
log.Printf("%s just pushed sshkey\n", std.Login) log.Printf("%s just pushed sshkey\n", std.Login)
if len(AuthorizedKeyLocation) > 0 { if len(AuthorizedKeysLocation) > 0 {
file, err := os.Create(AuthorizedKeyLocation) file, err := os.Create(AuthorizedKeysLocation)
if err != nil { if err != nil {
log.Fatal("Cannot create file", err) log.Fatal("Cannot create file", err)
goto sshpiperimport
} }
defer file.Close() defer file.Close()
dumpAuthorizedKeysFile(file) dumpAuthorizedKeysFile(file)
} }
sshpiperimport:
if len(SshPiperLocation) > 0 {
if err := os.MkdirAll(path.Join(SshPiperLocation, std.Login), 0777); err != nil {
log.Fatal("Cannot create sshpiper directory:", err)
} else {
file, err := os.Create(path.Join(SshPiperLocation, std.Login, "authorized_keys"))
if err != nil {
log.Fatal("Cannot create sshpiperd file", err)
goto onerr
}
defer file.Close()
std.dumpAuthorizedKeysFile(file)
}
}
onerr:
return "Key imported", nil return "Key imported", nil
} }
} }
@ -160,3 +216,20 @@ func dumpAuthorizedKeysFile(w io.Writer) {
} }
} }
} }
func (s Student) dumpAuthorizedKeysFile(w io.Writer) {
seen := map[string]interface{}{}
if keys, _ := s.getKeys(); keys != nil {
for _, k := range keys {
if _, exists := seen[k.Key]; exists {
continue
} else {
seen[k.Key] = true
}
s, _ := k.GetStudent()
w.Write([]byte("ssh-ed25519 " + k.Key + fmt.Sprintf(" Student#%d-%q\n", k.IdStudent, s.Login)))
}
}
}