package main import ( "bytes" "crypto/hmac" "encoding/json" "encoding/hex" "errors" "fmt" "io" "log" "net/http" "os" "os/exec" "path" "strconv" "strings" "time" "github.com/julienschmidt/httprouter" ) var AuthorizedKeysLocation = "/root/.ssh/authorized_keys" var SshPiperLocation = "/var/sshpiper/" func init() { router.GET("/sshkeys", apiHandler( func(httprouter.Params, []byte) (interface{}, error) { return getStudentKeys() })) router.POST("/sshkeys", rawHandler(receiveKey)) router.GET("/sshkeys/authorizedkeys", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { dumpAuthorizedKeysFile(w) }) router.GET("/api/students/:sid/hassshkeys", apiHandler(studentHandler(hasSSHKeys))) router.GET("/api/students/:sid/authorizedkeys", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { if sid, err := strconv.Atoi(string(ps.ByName("sid"))); err != nil { if student, err := getStudentByLogin(ps.ByName("sid")); err != nil { http.Error(w, "Student doesn't exist.", http.StatusNotFound) } else { student.dumpAuthorizedKeysFile(w) } } else if student, err := getStudent(sid); err != nil { http.Error(w, "Student doesn't exist.", http.StatusNotFound) } else { student.dumpAuthorizedKeysFile(w) } }) } func hasSSHKeys(student Student, body []byte) (interface{}, error) { if keys, err := student.getKeys(); err != nil { return nil, err } else { return len(keys) > 0, nil } } type StudentKey struct { Id int64 `json:"id"` IdStudent int64 `json:"id_student"` Key string `json:"key"` Time time.Time `json:"time"` } func getStudentKeys() (keys []StudentKey, err error) { if rows, errr := DBQuery("SELECT id_key, id_student, sshkey, time FROM student_keys"); errr != nil { return nil, errr } else { defer rows.Close() for rows.Next() { var k StudentKey if err = rows.Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time); err != nil { return } keys = append(keys, k) } if err = rows.Err(); err != nil { return } return } } func (s Student) getKeys() (keys []StudentKey, err error) { if rows, errr := DBQuery("SELECT id_key, id_student, sshkey, time FROM student_keys WHERE id_student = ?", s.Id); errr != nil { return nil, errr } else { defer rows.Close() for rows.Next() { var k StudentKey if err = rows.Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time); err != nil { return } keys = append(keys, k) } if err = rows.Err(); err != nil { return } return } } func getStudentKey(id int) (k StudentKey, err error) { err = DBQueryRow("SELECT id_key, id_student, sshkey, time FROM student_keys WHERE id_key=?", id).Scan(&k.Id, &k.IdStudent, &k.Key, &k.Time) return } func (s Student) NewKey(key string) (k StudentKey, err error) { // Check key before importing it cmd := exec.Command("ssh-keygen", "-l", "-f", "-") cmd.Stdin = strings.NewReader(key) var stdoutStderr []byte stdoutStderr, err = cmd.CombinedOutput() if err != nil { if _, ok := err.(*exec.ExitError); ok { err = errors.New(string(stdoutStderr)) } return } chunks := bytes.Fields(stdoutStderr) keytype := string(chunks[len(chunks)-1]) minkeysize := 2048 if keytype == "(ED25519)" || keytype == "(ECDSA)" { minkeysize = 256 } var bits int if bits, err = strconv.Atoi(string(chunks[0])); err != nil { return } else if bits < minkeysize { err = errors.New("Keysize too small") return } // Sanitize the given key keyf := strings.Fields(key) if len(keyf) < 2 { err = errors.New("Unexpected key file, this should never happen") return } key = keyf[0] + " " + keyf[1] if res, err := DBExec("INSERT INTO student_keys (id_student, sshkey, time) VALUES (?, ?, ?)", s.Id, key, time.Now()); err != nil { return StudentKey{}, err } else if kid, err := res.LastInsertId(); err != nil { return StudentKey{}, err } else { s.UnlockNewChallenge(len(challenges), "") return StudentKey{kid, s.Id, key, time.Now()}, nil } } func (k StudentKey) GetStudent() (Student, error) { return getStudent(int(k.IdStudent)) } func (k StudentKey) Update() (int64, error) { if res, err := DBExec("UPDATE student_keys SET id_student = ?, sshkey = ?, time = ? WHERE id_key = ?", k.IdStudent, k.Key, k.Time, k.Id); err != nil { return 0, err } else if nb, err := res.RowsAffected(); err != nil { return 0, err } else { return nb, err } } func (k StudentKey) Delete() (int64, error) { if res, err := DBExec("DELETE FROM student_keys WHERE id_key = ?", k.Id); err != nil { return 0, err } else if nb, err := res.RowsAffected(); err != nil { return 0, err } else { return nb, err } } func receiveKey(r *http.Request, ps httprouter.Params, body []byte) (interface{}, error) { var gt givenToken if err := json.Unmarshal(body, >); err != nil { return nil, err } gt.token = make([]byte, hex.DecodedLen(len(gt.Token))) if _, err := hex.Decode(gt.token, []byte(gt.Token)); err != nil { return nil, err } if std, err := getStudentByLogin(gt.Login); err != nil { return nil, err } else if len(gt.Data) < 2 { return nil, errors.New("No key found!") } else { pkey := std.GetPKey() data := [][]byte{} for _, d := range gt.Data { data = append(data, []byte(d)) } if expectedToken, err := GenerateToken(pkey, 0, data...); err != nil { return nil, err } else if ! hmac.Equal(expectedToken, gt.token) { return nil, errors.New("This is not the expected token.") } if _, err := std.NewKey(gt.Data[0] + " " + gt.Data[1]); err != nil { return nil, err } log.Printf("%s just pushed sshkey\n", std.Login) if len(AuthorizedKeysLocation) > 0 { file, err := os.Create(AuthorizedKeysLocation) if err != nil { log.Fatal("Cannot create file", err) goto sshpiperimport } defer file.Close() dumpAuthorizedKeysFile(file) } sshpiperimport: if len(SshPiperLocation) > 0 { if err := os.MkdirAll(path.Join(SshPiperLocation, std.Login), 0777); err != nil { log.Fatal("Cannot create sshpiper directory:", err) } else { file, err := os.Create(path.Join(SshPiperLocation, std.Login, "authorized_keys")) if err != nil { log.Fatal("Cannot create sshpiperd file", err) goto onerr } defer file.Close() std.dumpAuthorizedKeysFile(file) os.Symlink(path.Join(SshPiperLocation, "sshpiper_upstream"), path.Join(SshPiperLocation, std.Login, "sshpiper_upstream")) os.Symlink(path.Join(SshPiperLocation, "id_rsa"), path.Join(SshPiperLocation, std.Login, "id_rsa")) } } onerr: return "Key imported", nil } } func dumpAuthorizedKeysFile(w io.Writer) { seen := map[string]interface{}{} if keys, _ := getStudentKeys(); keys != nil { for _, k := range keys { if _, exists := seen[k.Key]; exists { continue } else { seen[k.Key] = true } s, _ := k.GetStudent() w.Write([]byte("command=\"/root/adlin.sh " + fmt.Sprintf("%d", k.IdStudent) + " '" + s.Login + "'\",restrict " + k.Key + fmt.Sprintf(" Student#%d-%q\n", k.IdStudent, s.Login))) } } } func (s Student) dumpAuthorizedKeysFile(w io.Writer) { seen := map[string]interface{}{} if keys, _ := s.getKeys(); keys != nil { for _, k := range keys { if _, exists := seen[k.Key]; exists { continue } else { seen[k.Key] = true } s, _ := k.GetStudent() w.Write([]byte(k.Key + fmt.Sprintf(" Student#%d-%q\n", k.IdStudent, s.Login))) } } }