about summary refs log tree commit diff
path: root/src/user.go
diff options
context:
space:
mode:
authorEmile <git@emile.space>2024-08-16 19:50:26 +0200
committerEmile <git@emile.space>2024-08-16 19:50:26 +0200
commit1a57267a17c2fc17fb6e104846fabc3e363c326c (patch)
tree1e574e3a80622086dc3c81ff9cba65ef7049b1a9 /src/user.go
initial commit
Diffstat (limited to 'src/user.go')
-rw-r--r--src/user.go666
1 files changed, 666 insertions, 0 deletions
diff --git a/src/user.go b/src/user.go
new file mode 100644
index 0000000..cc77657
--- /dev/null
+++ b/src/user.go
@@ -0,0 +1,666 @@
+package main
+
+import (
+	"fmt"
+	"html/template"
+	"log"
+	"net/http"
+	"strconv"
+	"time"
+
+	"github.com/gorilla/mux"
+	"golang.org/x/crypto/argon2"
+)
+
+type User struct {
+	ID           int
+	Name         string
+	PasswordHash []byte
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// GENERAL PURPOSE
+
+func UserRegister(username string, passwordHash []byte) (int, error) {
+	id, err := globalState.InsertUser(User{Name: username, PasswordHash: passwordHash})
+	if err != nil {
+		return 0, err
+	} else {
+		return id, nil
+	}
+}
+
+// UserCheckPasswordHash returns a boolean that is true if the users password
+// is correct and false if the users password is false and thus doesn't match
+// the one stored in the database
+func UserCheckPasswordHash(username string, passwordHash []byte) bool {
+	return globalState.CheckUserHash(username, passwordHash)
+}
+
+// UserUpdatePasswordHash does exactly that
+func UserUpdatePasswordHash(orig_username string, passwordHash []byte) error {
+	return globalState.UpdateUserPasswordHash(orig_username, passwordHash)
+}
+
+func UserUpdateUsername(id int, new_username string) error {
+	return globalState.UpdateUserUsername(id, new_username)
+}
+
+func UserLinkBot(username string, botid int) error {
+	return globalState.LinkUserBot(username, botid)
+}
+
+func UserGetBotsUsingUsername(username string) ([]Bot, error) {
+	return globalState.GetUserBotsUsername(username)
+}
+
+func UserGetBotsUsingUserID(userid int) ([]Bot, error) {
+	return globalState.GetUserBotsId(userid)
+}
+
+func UserGetUserFromID(userid int) (User, error) {
+	return globalState.GetUserFromId(userid)
+}
+
+func UserGetUserFromUsername(username string) (User, error) {
+	return globalState.GetUserFromUsername(username)
+}
+
+func UserGetAll() ([]User, error) {
+	return globalState.GetAllUsers()
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// DATABASE
+
+func (s *State) InsertUser(user User) (int, error) {
+	res, err := s.db.Exec("INSERT INTO users VALUES(NULL,?,?,?);", time.Now(), user.Name, user.PasswordHash)
+	if err != nil {
+		return 0, err
+	}
+
+	var id int64
+	if id, err = res.LastInsertId(); err != nil {
+		return 0, err
+	}
+	return int(id), nil
+}
+
+// returns true if the password matches
+func (s *State) CheckUserHash(username string, passwordHash []byte) bool {
+	var created time.Time
+	err := s.db.QueryRow("SELECT created_at FROM users WHERE name=? AND passwordHash=?", username, passwordHash).Scan(&created)
+	switch {
+	case err != nil:
+		return false
+	default:
+		return true
+	}
+}
+
+func (s *State) UpdateUserPasswordHash(username string, passwordHash []byte) error {
+	_, err := s.db.Exec("UPDATE users SET passwordHash=? WHERE name=?", passwordHash, username)
+	if err != nil {
+		return err
+	} else {
+		return nil
+	}
+}
+
+func (s *State) UpdateUserUsername(id int, new_username string) error {
+	_, err := s.db.Exec("UPDATE users SET name=? WHERE id=?", new_username, id)
+	if err != nil {
+		return err
+	} else {
+		return nil
+	}
+}
+
+// Links the given bot to the given user in the user_bot_rel table
+func (s *State) LinkUserBot(username string, botid int) error {
+	_, err := s.db.Exec("INSERT INTO user_bot_rel VALUES ((SELECT id FROM users WHERE name=?), ?)", username, botid)
+	if err != nil {
+		return err
+	} else {
+		return nil
+	}
+}
+
+// Links the given bot to the given user in the user_bot_rel table
+func (s *State) GetUserFromId(id int) (User, error) {
+	var user_id int
+	var username string
+	err := s.db.QueryRow("SELECT id, name FROM users WHERE id=?", id).Scan(&user_id, &username)
+	if err != nil {
+		return User{}, err
+	} else {
+		return User{user_id, username, nil}, nil
+	}
+}
+
+func (s *State) GetUserFromUsername(username string) (User, error) {
+	var id int
+	var name string
+	err := s.db.QueryRow("SELECT id, name FROM users WHERE name=?", username).Scan(&id, &name)
+	if err != nil {
+		return User{}, err
+	} else {
+		return User{id, name, nil}, nil
+	}
+}
+
+// Returns the bots belonging to the given user
+func (s *State) GetUserBotsUsername(username string) ([]Bot, error) {
+	rows, err := s.db.Query("SELECT id, name, source FROM bots b LEFT JOIN user_bot_rel ub ON ub.bot_id = b.id WHERE ub.user_id=(SELECT id FROM users WHERE name=?)", username)
+	defer rows.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	var bots []Bot
+	for rows.Next() {
+		var bot Bot
+		if err := rows.Scan(&bot.ID, &bot.Name, &bot.Source); err != nil {
+			return bots, err
+		}
+		bots = append(bots, bot)
+	}
+	if err = rows.Err(); err != nil {
+		return bots, err
+	}
+	return bots, nil
+}
+
+// Returns the bots belonging to the given user
+func (s *State) GetUserBotsId(id int) ([]Bot, error) {
+	rows, err := s.db.Query("SELECT id, name, source FROM bots b LEFT JOIN user_bot_rel ub ON ub.bot_id = b.id WHERE ub.user_id=?", id)
+	defer rows.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	var bots []Bot
+	for rows.Next() {
+		var bot Bot
+		if err := rows.Scan(&bot.ID, &bot.Name, &bot.Source); err != nil {
+			return bots, err
+		}
+		bots = append(bots, bot)
+	}
+	if err = rows.Err(); err != nil {
+		return bots, err
+	}
+	return bots, nil
+}
+
+// Returns the bots belonging to the given user
+func (s *State) GetAllUsers() ([]User, error) {
+	rows, err := s.db.Query("SELECT id, name FROM users")
+	defer rows.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	var users []User
+	for rows.Next() {
+		var user User
+		if err := rows.Scan(&user.ID, &user.Name); err != nil {
+			return users, err
+		}
+		users = append(users, user)
+	}
+	if err = rows.Err(); err != nil {
+		return users, err
+	}
+	return users, nil
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// HTTP
+
+func loginHandler(w http.ResponseWriter, r *http.Request) {
+	switch r.Method {
+	case "GET":
+		log.Println("GET /login")
+		// define data
+		log.Println("[d] Defining breadcrumbs")
+		data := map[string]interface{}{}
+		data["pagelink1"] = Link{"login", "/login"}
+		data["pagelink1options"] = []Link{
+			{Name: "register", Target: "/register"},
+		}
+		data["pagelinkauth"] = []Link{
+			{Name: "register/", Target: "/register"},
+		}
+
+		// session foo
+		log.Println("[d] Getting session")
+		session, _ := globalState.sessions.Get(r, "session")
+		username := session.Values["username"]
+
+		// get the user
+		if username != nil {
+			log.Printf("[d] Getting the user %s\n", username.(string))
+			user, err := UserGetUserFromUsername(username.(string))
+			if user.Name == "" {
+				log.Println("no user found")
+			} else if err != nil {
+				log.Println(err)
+				msg := "Error: could not get the user for given username"
+				http.Redirect(w, r, fmt.Sprintf("/login?res=%s", msg), http.StatusSeeOther)
+				return
+			} else {
+				data["user"] = user
+			}
+		}
+
+		// display errors passed via query parameters
+		log.Println("[d] Getting previous results")
+		queryres := r.URL.Query().Get("res")
+		if queryres != "" {
+			data["res"] = queryres
+		}
+
+		// get the template
+		log.Println("[d] Getting the template")
+		t, err := template.ParseGlob("./templates/*.html")
+		if err != nil {
+			log.Println("Error parsing the login template: ", err)
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Error reading template file"))
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// exec!
+		log.Println("[d] Executing the template")
+		t.ExecuteTemplate(w, "login", data)
+
+	case "POST":
+		// parse the post parameters
+		r.ParseForm()
+		username := r.Form.Get("username")
+		password := r.Form.Get("password")
+
+		// if we've got a password, hash it and compare it with the stored one
+		if password != "" {
+			passwordHash := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 32)
+
+			// check if it's valid
+			valid := UserCheckPasswordHash(username, passwordHash)
+			if valid {
+
+				// if it's valid, we set a session for the user
+				session, _ := globalState.sessions.Get(r, "session")
+				session.Values["username"] = username
+				err := session.Save(r, w)
+				if err != nil {
+					http.Error(w, err.Error(), http.StatusInternalServerError)
+					return
+				}
+
+				http.Redirect(w, r, "/", http.StatusSeeOther)
+				return
+
+			} else {
+				// invalid password
+				http.Redirect(w, r, "/login?err=Invalid+Password", http.StatusSeeOther)
+				return
+			}
+		} else {
+			// empty password
+			http.Redirect(w, r, "/login?err=Empty+Password", http.StatusSeeOther)
+			return
+		}
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}
+
+func registerHandler(w http.ResponseWriter, r *http.Request) {
+	switch r.Method {
+	case "GET":
+		// define data
+		data := map[string]interface{}{}
+
+		// get the session
+		session, _ := globalState.sessions.Get(r, "session")
+		username := session.Values["username"]
+		data["pagelink1"] = Link{"register", "/register"}
+		data["pagelink1options"] = []Link{
+			{Name: "login", Target: "/login"},
+		}
+		data["pagelinkauth"] = []Link{
+			{Name: "login/", Target: "/login"},
+		}
+
+		log.Println(username)
+
+		if username != nil {
+			data["logged_in"] = true
+		}
+
+		// get the template
+		t, err := template.ParseGlob("./templates/*.html")
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Error reading template file"))
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// exec!
+		t.ExecuteTemplate(w, "register", data)
+
+	case "POST":
+		// parse the post parameters
+		r.ParseForm()
+		username := r.Form.Get("username")
+		password1 := r.Form.Get("password1")
+		password2 := r.Form.Get("password2")
+
+		if len(username) >= 64 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! Please enter less than 64 chars!"))
+			return
+		}
+
+		if len(password1) >= 256 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! Don't overdo with the length please!"))
+			return
+		}
+
+		if password1 != password2 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! The passwords you entered don't match!"))
+			return
+		}
+
+		// if we've got a password, hash it and store it and create a User
+		if password1 != "" {
+			passwordHash := argon2.IDKey([]byte(password1), []byte(salt), 1, 64*1024, 4, 32)
+
+			_, err := UserRegister(username, passwordHash)
+			if err != nil {
+				w.WriteHeader(http.StatusInternalServerError)
+				w.Write([]byte("500 - We had problems inserting you into the DB"))
+				return
+			}
+
+			http.Redirect(w, r, "/login", http.StatusSeeOther)
+			return
+		}
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}
+
+func logoutHandler(w http.ResponseWriter, r *http.Request) {
+	switch r.Method {
+	case "POST":
+		session, _ := globalState.sessions.Get(r, "session")
+		session.Options.MaxAge = -1
+		err := session.Save(r, w)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+		}
+		http.Redirect(w, r, "/", http.StatusSeeOther)
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}
+
+func userHandler(w http.ResponseWriter, r *http.Request) {
+	vars := mux.Vars(r)
+	id, err := strconv.Atoi(vars["id"])
+	if err != nil {
+		http.Redirect(w, r, "/user", http.StatusSeeOther)
+	}
+
+	switch r.Method {
+	case "GET":
+		// define data
+		data := map[string]interface{}{}
+		data["pagelink1"] = Link{"user", "/user"}
+		data["pagelink1options"] = []Link{
+			{Name: "bot", Target: "/bot"},
+			{Name: "battle", Target: "/battle"},
+		}
+
+		// session foo
+		session, _ := globalState.sessions.Get(r, "session")
+		username := session.Values["username"].(string)
+
+		// the the user making the request
+		user, err := UserGetUserFromUsername(username)
+		if err != nil {
+			http.Redirect(w, r, "/login", http.StatusSeeOther)
+			return
+		} else {
+			data["user"] = user
+		}
+
+		// get the target user using the provided id
+		targetUser, err := UserGetUserFromID(id)
+		if err != nil {
+			data["err"] = "Could not find that user"
+		} else {
+			data["targetUser"] = targetUser
+		}
+
+		// define the breadcrumbs
+		data["pagelink2"] = Link{targetUser.Name, fmt.Sprintf("/%s", targetUser.Name)}
+
+		allUserNames, err := UserGetAll()
+		var opts []Link
+		for _, user := range allUserNames {
+			opts = append(opts, Link{Name: user.Name, Target: fmt.Sprintf("/%d", user.ID)})
+		}
+		data["pagelink2options"] = opts
+
+		// get the bots for the given user
+		bots, err := UserGetBotsUsingUserID(id)
+		if err != nil {
+			http.Redirect(w, r, "/user", http.StatusSeeOther)
+		} else {
+			data["bots"] = bots
+		}
+
+		// get the template
+		t, err := template.ParseGlob("./templates/*.html")
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Error reading template file"))
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// exec!
+		t.ExecuteTemplate(w, "user", data)
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}
+
+func usersHandler(w http.ResponseWriter, r *http.Request) {
+	switch r.Method {
+	case "GET":
+		// define data
+		data := map[string]interface{}{}
+		data["pagelink1"] = Link{Name: "user", Target: "/user"}
+		data["pagelink1options"] = []Link{
+			{Name: "bot", Target: "/bot"},
+			{Name: "battle", Target: "/battle"},
+		}
+
+		// sessions
+		session, _ := globalState.sessions.Get(r, "session")
+		username := session.Values["username"].(string)
+
+		// get the user
+		user, err := UserGetUserFromUsername(username)
+		if err != nil {
+			http.Redirect(w, r, "/login", http.StatusSeeOther)
+			return
+		} else {
+			data["user"] = user
+		}
+
+		// get all users
+		users, err := UserGetAll()
+		data["users"] = users
+
+		// get the template
+		t, err := template.ParseGlob("./templates/*.html")
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Error reading template file"))
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// exec!
+		t.ExecuteTemplate(w, "users", data)
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}
+
+func profileHandler(w http.ResponseWriter, r *http.Request) {
+	vars := mux.Vars(r)
+	id, err := strconv.Atoi(vars["id"])
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte("500 - Error reading template file"))
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	switch r.Method {
+	case "GET":
+		// define data
+		data := map[string]interface{}{}
+		data["pagelink1"] = Link{"user", "/user"}
+		data["pagelink1options"] = []Link{
+			{Name: "bot", Target: "/bot"},
+			{Name: "battle", Target: "/battle"},
+		}
+
+		session, _ := globalState.sessions.Get(r, "session")
+		username := session.Values["username"].(string)
+
+		target_user, err := UserGetUserFromID(id)
+		if err != nil {
+			// w.WriteHeader(http.StatusUnauthorized)
+			// w.Write([]byte("500 - Error reading template file"))
+			// http.Error(w, err.Error(), http.StatusInternalServerError)
+			data["err"] = "Error getting with the given id"
+		}
+
+		if username != target_user.Name {
+			// w.WriteHeader(http.StatusInternalServerError)
+			// w.Write([]byte("500 - Error reading template file"))
+			// http.Error(w, err.Error(), http.StatusInternalServerError)
+			data["err"] = "You aren't allowed to edit any user except yourself"
+		}
+
+		editing_user, err := UserGetUserFromUsername(username)
+		if err != nil {
+			data["err"] = "Coulnd't get a user for that id"
+		}
+
+		data["user"] = editing_user
+		data["target_user"] = target_user
+
+		data["pagelink2"] = Link{target_user.Name, fmt.Sprintf("/%d", id)}
+		allUserNames, err := UserGetAll()
+		var opts []Link
+		for _, user := range allUserNames {
+			opts = append(opts, Link{Name: user.Name, Target: fmt.Sprintf("/%d", user.ID)})
+		}
+		data["pagelink2options"] = opts
+
+		data["pagelink3"] = Link{"profile", "/profile"}
+
+		if username != "" {
+			data["username"] = username
+		}
+
+		// get the template
+		t, err := template.ParseGlob("./templates/*.html")
+		if err != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Error reading template file"))
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// exec!
+		t.ExecuteTemplate(w, "profile", data)
+
+	case "POST":
+		session, _ := globalState.sessions.Get(r, "session")
+		orig_username := session.Values["username"].(string)
+
+		// parse the post parameters
+		r.ParseForm()
+		new_username := r.Form.Get("username")
+		password1 := r.Form.Get("password1")
+		password2 := r.Form.Get("password2")
+
+		if len(new_username) >= 64 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! Please enter less than 64 chars!"))
+			return
+		}
+
+		if len(password1) >= 256 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! Don't overdo with the length please!"))
+			return
+		}
+
+		if password1 != password2 {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("500 - Oi', Backend here! The passwords you entered don't match!"))
+			return
+		}
+
+		// first update the password, as they might have also changed their
+		// username
+		if password1 != "" {
+			passwordHash := argon2.IDKey([]byte(password1), []byte(salt), 1, 64*1024, 4, 32)
+
+			err := UserUpdatePasswordHash(orig_username, passwordHash)
+			if err != nil {
+				w.WriteHeader(http.StatusInternalServerError)
+				w.Write([]byte("500 - We had problems inserting your new pw into the DB"))
+				return
+			}
+		}
+
+		if new_username != "" {
+			err := UserUpdateUsername(id, new_username)
+			if err != nil {
+				w.WriteHeader(http.StatusInternalServerError)
+				w.Write([]byte("500 - We had problems inserting your new uname into the DB"))
+				return
+			}
+
+			// after changing the username, we also have to update the username
+			//  stored in the session
+			session.Values["username"] = new_username
+			err = session.Save(r, w)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+		}
+
+		http.Redirect(w, r, fmt.Sprintf("/user/%d/profile", id), http.StatusSeeOther)
+		return
+	default:
+		http.Redirect(w, r, "/", http.StatusMethodNotAllowed)
+	}
+}