diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/arch.go | 41 | ||||
-rw-r--r-- | src/battle.go | 846 | ||||
-rw-r--r-- | src/bit.go | 42 | ||||
-rw-r--r-- | src/bot.go | 842 | ||||
-rw-r--r-- | src/db.go | 127 | ||||
-rw-r--r-- | src/http.go | 55 | ||||
-rw-r--r-- | src/log.go | 34 | ||||
-rw-r--r-- | src/main.go | 75 | ||||
-rw-r--r-- | src/r2.go | 24 | ||||
-rw-r--r-- | src/sql.go | 31 | ||||
-rw-r--r-- | src/sqlitestore.go | 284 | ||||
-rw-r--r-- | src/user.go | 666 |
12 files changed, 3067 insertions, 0 deletions
diff --git a/src/arch.go b/src/arch.go new file mode 100644 index 0000000..52dc516 --- /dev/null +++ b/src/arch.go @@ -0,0 +1,41 @@ +package main + +type Arch struct { + ID int + Name string + Enabled bool +} + +////////////////////////////////////////////////////////////////////////////// +// GENERAL PURPOSE + +func ArchGetAll() ([]Arch, error) { + return globalState.GetAllArchs() +} + +////////////////////////////////////////////////////////////////////////////// +// DATABASE + +func (s *State) GetAllArchs() ([]Arch, error) { + rows, err := s.db.Query("SELECT id, name FROM archs") + defer rows.Close() + if err != nil { + return nil, err + } + + var archs []Arch + for rows.Next() { + var arch Arch + if err := rows.Scan(&arch.ID, &arch.Name); err != nil { + return archs, err + } + archs = append(archs, arch) + } + if err = rows.Err(); err != nil { + return archs, err + } + return archs, nil +} + +////////////////////////////////////////////////////////////////////////////// +// HTTP diff --git a/src/battle.go b/src/battle.go new file mode 100644 index 0000000..3d418f2 --- /dev/null +++ b/src/battle.go @@ -0,0 +1,846 @@ +package main + +import ( + "fmt" + "html/template" + "log" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gorilla/mux" +) + +type Battle struct { + ID int + Name string + Bots []Bot + Owners []User + Public bool + Archs []Arch + Bits []Bit +} + +////////////////////////////////////////////////////////////////////////////// +// GENERAL PURPOSE + +func BattleGetAll() ([]Battle, error) { + return globalState.GetAllBattles() +} + +func BattleCreate(name string, public bool) (int, error) { + return globalState.InsertBattle(Battle{Name: name, Public: public}) +} + +func BattleLinkBot(botid int, battleid int) error { + return globalState.LinkBotBattle(botid, battleid) +} + +func BattleGetByIdDeep(id int) (Battle, error) { + return globalState.GetBattleByIdDeep(id) +} + +func BattleUpdate(battle Battle) error { + return globalState.UpdateBattle(battle) +} + +func BattleLinkArchIDs(battleid int, archIDs []int) error { + return globalState.LinkArchIDsToBattle(battleid, archIDs) +} + +func BattleLinkBitIDs(battleid int, bitIDs []int) error { + return globalState.LinkBitIDsToBattle(battleid, bitIDs) +} + +////////////////////////////////////////////////////////////////////////////// +// DATABASE + +func (s *State) InsertBattle(battle Battle) (int, error) { + res, err := s.db.Exec("INSERT INTO battles VALUES(NULL,?,?,?);", time.Now(), battle.Name, battle.Public) + if err != nil { + return -1, err + } + + var id int64 + if id, err = res.LastInsertId(); err != nil { + return -1, err + } + return int(id), nil +} + +func (s *State) UpdateBattle(battle Battle) error { + _, err := s.db.Exec("UPDATE battles SET name=?, public=? WHERE id=?", battle.Name, battle.Public, battle.ID) + if err != nil { + return err + } + return nil +} + +func (s *State) LinkBotBattle(botid int, battleid int) error { + _, err := s.db.Exec("INSERT INTO bot_battle_rel VALUES (?, ?)", botid, battleid) + if err != nil { + log.Println("Error linking bot to battle: ", err) + return err + } else { + return nil + } +} + +func (s *State) LinkArchIDsToBattle(battleid int, archIDs []int) error { + // delete preexisting links + _, err := s.db.Exec("DELETE FROM arch_battle_rel WHERE battle_id=?;", battleid) + if err != nil { + return err + } + + // yes, we're building this by hand, but as we only insert int's I'm just confident that whoever + // gets some sqli here just deserves it :D + query := "INSERT INTO arch_battle_rel (arch_id, battle_id) VALUES" + for idx, id := range archIDs { + query += fmt.Sprintf("(%d, %d)", id, battleid) + if idx != len(archIDs)-1 { + query += ", " + } + } + query += ";" + log.Println(query) + + _, err = s.db.Exec(query) + if err != nil { + log.Println("LinkArchIDsToBattle err: ", err) + return err + } else { + return nil + } +} + +func (s *State) LinkBitIDsToBattle(battleid int, bitIDs []int) error { + // delete preexisting links + _, err := s.db.Exec("DELETE FROM bit_battle_rel WHERE battle_id=?;", battleid) + if err != nil { + return err + } + + // yes, we're building this by hand, but as we only insert int's I'm just confident that whoever + // gets some sqli here just deserves it :D + query := "INSERT INTO bit_battle_rel (bit_id, battle_id) VALUES" + for idx, id := range bitIDs { + query += fmt.Sprintf("(%d, %d)", id, battleid) + if idx != len(bitIDs)-1 { + query += ", " + } + } + query += ";" + log.Println(query) + + _, err = s.db.Exec(query) + if err != nil { + log.Println("LinkBitIDsToBattle err: ", err) + return err + } else { + return nil + } +} + +func (s *State) GetAllBattles() ([]Battle, error) { + rows, err := s.db.Query("SELECT id, name FROM battles;") + defer rows.Close() + if err != nil { + return nil, err + } + + var battles []Battle + for rows.Next() { + var battle Battle + if err := rows.Scan(&battle.ID, &battle.Name); err != nil { + log.Println(err) + return battles, err + } + battles = append(battles, battle) + } + if err = rows.Err(); err != nil { + log.Println(err) + return battles, err + } + return battles, nil +} + +func (s *State) GetBattleByIdDeep(id int) (Battle, error) { + var battleid int + var battlename string + var battlepublic bool + + var botids string + var botnames string + + var userids string + var usernames string + + var archids string + var archnames string + + var bitids string + var bitnames string + + // battles have associated bots and users, we're fetching 'em all! + + // This fetches the battles and relates the associated bots, users, archs and bits + + err := s.db.QueryRow(` + SELECT DISTINCT + ba.id, ba.name, ba.public, + COALESCE(group_concat(DISTINCT bb.bot_id), ""), + COALESCE(group_concat(DISTINCT bo.name), ""), + COALESCE(group_concat(DISTINCT ub.user_id), ""), + COALESCE(group_concat(DISTINCT us.name), ""), + COALESCE(group_concat(DISTINCT ab.arch_id), ""), + COALESCE(group_concat(DISTINCT ar.name), ""), + COALESCE(group_concat(DISTINCT bitbat.bit_id), ""), + COALESCE(group_concat(DISTINCT bi.name), "") + FROM battles ba + + LEFT JOIN bot_battle_rel bb ON bb.battle_id = ba.id + LEFT JOIN bots bo ON bo.id = bb.bot_id + + LEFT JOIN user_battle_rel ub ON ub.battle_id = ba.id + LEFT JOIN users us ON us.id = ub.user_id + + LEFT JOIN arch_battle_rel ab ON ab.battle_id = ba.id + LEFT JOIN archs ar ON ar.id = ab.arch_id + + LEFT JOIN bit_battle_rel bitbat ON bitbat.battle_id = ba.id + LEFT JOIN bits bi ON bi.id = bitbat.bit_id + + WHERE ba.id=? + GROUP BY ba.id; + `, id).Scan(&battleid, &battlename, &battlepublic, &botids, &botnames, &userids, &usernames, &archids, &archnames, &bitids, &bitnames) + if err != nil { + log.Println("Err making GetBattleByID query: ", err) + return Battle{}, err + } + + log.Println("battleid: ", battleid) + log.Println("battlename: ", battlename) + log.Println("battlepublic: ", battlepublic) + log.Println("botids: ", botids) + log.Println("botnames: ", botnames) + log.Println("userids: ", userids) + log.Println("usernames: ", usernames) + log.Println("archids: ", archids) + log.Println("archnames: ", archnames) + log.Println("bitids: ", bitids) + log.Println("bitnames: ", bitnames) + + // The below is a wonderful examle of how golang could profit from macros + // I should just have done this all in common lisp tbh. + + // assemble the bots + botIDList := strings.Split(botids, ",") + botNameList := strings.Split(botnames, ",") + + // Using strings.Split on an empty string returns a list containing + // nothing with a length of one + // https://go.dev/play/p/N1D-OcwiVAs + + var bots []Bot + if botIDList[0] != "" { + for i, _ := range botIDList { + id, err := strconv.Atoi(botIDList[i]) + if err != nil { + log.Println("Err handling bots: ", err) + return Battle{}, err + } + bots = append(bots, Bot{id, botNameList[i], "", []User{}, []Arch{}, []Bit{}}) + } + } else { + bots = []Bot{} + } + + // assemble the users + userIDList := strings.Split(userids, ",") + userNameList := strings.Split(usernames, ",") + + var users []User + if userIDList[0] != "" { + for i, _ := range userIDList { + id, err := strconv.Atoi(userIDList[i]) + if err != nil { + log.Println("Err handling users: ", err) + return Battle{}, err + } + users = append(users, User{id, userNameList[i], []byte{}}) + } + } else { + users = []User{} + } + + // assemble the archs + archIDList := strings.Split(archids, ",") + archNameList := strings.Split(archnames, ",") + + var archs []Arch + if archIDList[0] != "" { + for i, _ := range archIDList { + id, err := strconv.Atoi(archIDList[i]) + if err != nil { + log.Println("Err handling archs: ", err) + return Battle{}, err + } + archs = append(archs, Arch{id, archNameList[i], true}) + } + } else { + archs = []Arch{} + } + + // assemble the bits + bitIDList := strings.Split(bitids, ",") + bitNameList := strings.Split(bitnames, ",") + + var bits []Bit + if bitIDList[0] != "" { + for i, _ := range bitIDList { + id, err := strconv.Atoi(bitIDList[i]) + if err != nil { + log.Println("Err handling bits: ", err) + return Battle{}, err + } + bits = append(bits, Bit{id, bitNameList[i], true}) + } + } else { + bits = []Bit{} + } + + // return it all! + switch { + case err != nil: + log.Println("Overall err in the GetBattleByID func: ", err) + return Battle{}, err + default: + return Battle{ + ID: battleid, + Name: battlename, + Bots: bots, + Owners: users, + Public: battlepublic, + Archs: archs, + Bits: bits, + }, nil + } +} + +////////////////////////////////////////////////////////////////////////////// +// HTTP + +func battlesHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{Name: "battle", Target: "/battle"} + data["pagelink1options"] = []Link{ + {Name: "bot", Target: "/bot"}, + {Name: "user", Target: "/user"}, + } + data["pagelinknext"] = []Link{ + {Name: "new", Target: "/new"}, + } + + // sessions + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + // get the user + user, err := UserGetUserFromUsername(username) + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } else { + data["user"] = user + } + + // get all battles + battles, err := BattleGetAll() + data["battles"] = battles + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "battles", data) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func battleNewHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + + // breadcrumb foo + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + data["pagelink1"] = Link{Name: "battle", Target: "/battle"} + data["pagelink1options"] = []Link{ + {Name: "user", Target: "/user"}, + {Name: "bot", Target: "/bot"}, + } + data["pagelink2"] = Link{Name: "new", Target: "/new"} + data["pagelink2options"] = []Link{ + {Name: "list", Target: ""}, + } + + // display errors passed via query parameters + queryres := r.URL.Query().Get("err") + if queryres != "" { + data["res"] = queryres + } + + // get data needed + user, err := UserGetUserFromUsername(username) + if err != nil { + data["err"] = "Could not fetch the user" + } else { + data["user"] = user + } + + archs, err := ArchGetAll() + if err != nil { + data["err"] = "Could not fetch the archs" + } else { + data["archs"] = archs + } + + bits, err := BitGetAll() + if err != nil { + data["err"] = "Could not fetch the bits" + } else { + data["bits"] = bits + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "battleNew", data) + + case "POST": + // parse the post parameters + r.ParseForm() + name := r.Form.Get("name") + + var public bool + query_public := r.Form.Get("public") + if query_public == "on" { + public = true + } + + // gather the information from the arch and bit selection + var archIDs []int + var bitIDs []int + + for k, _ := range r.Form { + if strings.HasPrefix(k, "arch-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "arch-")) + if err != nil { + msg := "ERROR: Invalid arch id" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + archIDs = append(archIDs, id) + } + if strings.HasPrefix(k, "bit-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "bit-")) + if err != nil { + msg := "ERROR: Invalid bit id" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + bitIDs = append(bitIDs, id) + } + } + + if name != "" { + // create the battle itself + log.Println("Creating battle") + battleid, err := BattleCreate(name, public) + if err != nil { + log.Println("Error creating the battle using BattleCreate(): ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + + // link archs to battle + err = BattleLinkArchIDs(battleid, archIDs) + if err != nil { + log.Println("Error linking the arch ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + + // link bits to battle + err = BattleLinkBitIDs(battleid, bitIDs) + if err != nil { + log.Println("Error linking the bit ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + } else { + msg := "ERROR: Please provide a name" + http.Redirect(w, r, fmt.Sprintf("/battle/new?res=%s", msg), http.StatusSeeOther) + return + } + + http.Redirect(w, r, "/battle", http.StatusSeeOther) + return + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +// TODO(emile): add user creating battle as default owner +// TODO(emile): allow adding other users as owners to battles +// TODO(emile): implement submitting bots +// TODO(emile): implement running the battle +// TODO(emile): add a "start battle now" button +// TODO(emile): add a "battle starts at this time" field into the battle +// TODO(emile): figure out how time is stored and restored with the db +// TODO(emile): do some magic to display the current fight backlog with all info + +func battleSingleHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + battleid, err := strconv.Atoi(vars["id"]) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Invalid battle id")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{"battle", "/battle"} + data["pagelink1options"] = []Link{ + {Name: "user", Target: "/user"}, + {Name: "bot", Target: "/bot"}, + } + + // display errors passed via query parameters + queryres := r.URL.Query().Get("res") + if queryres != "" { + data["res"] = queryres + } + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + viewer, err := UserGetUserFromUsername(username) + if err != nil { + data["err"] = "Could not get the id four your username... Please contact an admin" + } + data["user"] = viewer + + // get the battle including it's users, bots, archs, bits + battle, err := BattleGetByIdDeep(int(battleid)) + data["battle"] = battle + data["botAmount"] = len(battle.Bots) + data["battleCount"] = (len(battle.Bots) * len(battle.Bots)) * 2 + + // define the breadcrumbs + data["pagelink2"] = Link{battle.Name, fmt.Sprintf("/%d", battle.ID)} + + allbattleNames, err := BattleGetAll() + var opts []Link + for _, battle := range allbattleNames { + opts = append(opts, Link{Name: battle.Name, Target: fmt.Sprintf("/%d", battle.ID)}) + } + data["pagelink2options"] = opts + + // get the bots of the user viewing the page, as they might want to submit them + myBots, err := UserGetBotsUsingUsername(username) + if err != nil { + log.Println("err: ", err) + http.Redirect(w, r, fmt.Sprintf("/battle/%d", battleid), http.StatusSeeOther) + return + } + data["myBots"] = myBots + + // get all architectures and set the enable flag on the ones that are enabled in the battle + archs, err := ArchGetAll() + if err != nil { + data["err"] = "Could not fetch the archs" + } else { + data["archs"] = archs + } + + for i, a := range archs { + for _, b := range battle.Archs { + if a.ID == b.ID { + archs[i].Enabled = true + } + } + } + + // get all bits and set the enable flag on the ones that are enabled in the battle + bits, err := BitGetAll() + if err != nil { + data["err"] = "Could not fetch the bits" + } else { + data["bits"] = bits + } + + for i, a := range bits { + for _, b := range battle.Bits { + if a.ID == b.ID { + bits[i].Enabled = true + } + } + } + + // check if we're allowed to edit + editable := false + for _, owner := range battle.Owners { + if owner.ID == viewer.ID { + editable = true + } + } + if editable == true { + data["editable"] = true + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "battleSingle", data) + + case "POST": + log.Println("POST!") + // checking if the user submitting the battle information is allowed to do so + // session, _ := globalState.sessions.Get(r, "session") + // username := session.Values["username"].(string) + + // get the user submitting + // log.Println("Getting the user submitting the change request...") + // requesting_user, err := UserGetUserFromUsername(username) + // if err != nil { + // log.Println("err: ", err) + // http.Redirect(w, r, fmt.Sprintf("/battle/%d", battleid), http.StatusSeeOther) + // return + // } + + // get the users the battle belongs to + // log.Println("Getting the user the battle belongs to...") + // orig_battle, err := BattleGetByIdDeep(int(battleid)) + // if err != nil { + // log.Println("err: ", err) + // http.Redirect(w, r, fmt.Sprintf("/battle/%d", battleid), http.StatusSeeOther) + // return + // } + + // check if the user submitting the change request is within the users the battle belongs to + // log.Println("Checking if edit is allowed...") + // allowed_to_edit := false + // for _, user := range orig_battle.Owners { + // if user.ID == requesting_user.ID { + // allowed_to_edit = true + // } + // } + + // if allowed_to_edit == false { + // msg := "ERROR: You aren't allowed to edit this battle!" + // http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + // return + // } + + // at this point, we're sure the user is allowed to edit the battle + + r.ParseForm() + + log.Println("r.Form: ", r.Form) + form_name := r.Form.Get("name") + + var public bool + if r.Form.Get("public") == "on" { + public = true + } + + // gather the information from the arch and bit selection + var archIDs []int + var bitIDs []int + + for k, _ := range r.Form { + if strings.HasPrefix(k, "arch-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "arch-")) + if err != nil { + msg := "ERROR: Invalid arch id" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s#settings", battleid, msg), http.StatusSeeOther) + return + } + archIDs = append(archIDs, id) + } + if strings.HasPrefix(k, "bit-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "bit-")) + if err != nil { + msg := "ERROR: Invalid bit id" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s#settings", battleid, msg), http.StatusSeeOther) + return + } + bitIDs = append(bitIDs, id) + } + } + + // link archs to battle + err = BattleLinkArchIDs(battleid, archIDs) + if err != nil { + log.Println("Error linking the arch ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s#settings", battleid, msg), http.StatusSeeOther) + return + } + + // link bits to battle + err = BattleLinkBitIDs(battleid, bitIDs) + if err != nil { + log.Println("Error linking the bit ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s#settings", battleid, msg), http.StatusSeeOther) + return + } + + new_battle := Battle{int(battleid), form_name, []Bot{}, []User{}, public, []Arch{}, []Bit{}} + + log.Println("Updating battle...") + err = BattleUpdate(new_battle) + if err != nil { + log.Println("err: ", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error inserting battle into db")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=Success!#settings", battleid), http.StatusSeeOther) + + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func battleSubmitHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + battleid, err := strconv.Atoi(vars["id"]) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Invalid battle id")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + switch r.Method { + case "POST": + r.ParseForm() + + log.Println("Adding bot to battle", battleid) + log.Println(r.Form) + + // get all the form values that contain the bot that shall be submitted + var botIDs []int + for k, _ := range r.Form { + if strings.HasPrefix(k, "bot-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "bot-")) + if err != nil { + msg := "ERROR: Invalid bot supplied" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + return + } + botIDs = append(botIDs, id) + } + } + + log.Println(botIDs) + + battle, err := BattleGetByIdDeep(battleid) + if err != nil { + msg := "ERROR: Couln't get the battle with the given id" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + return + } + log.Println(battle) + + // for all bots, get their bits and arch and compare them to the one of the battle + for _, id := range botIDs { + bot, err := BotGetById(id) + if err != nil { + msg := fmt.Sprintf("ERROR: Couldn't get bot with id %d", id) + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + return + } + + var archValid bool = false + for _, battle_arch := range battle.Archs { + for _, bot_arch := range bot.Archs { + if battle_arch.ID == bot_arch.ID { + archValid = true + } + } + } + + var bitValid bool = false + for _, battle_bit := range battle.Bits { + for _, bot_bit := range bot.Bits { + if battle_bit.ID == bot_bit.ID { + bitValid = true + } + } + } + + if archValid && bitValid { + log.Printf("arch and bit valid, adding bot with id %d to battle with id %d\n", id, battleid) + BattleLinkBot(id, battleid) + } else { + if archValid == false { + msg := "Bot has an invalid architecture!" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + return + } + if bitValid == false { + msg := "Bot has an invalid 'bit-ness'!" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + return + } + } + + log.Println(bot) + } + msg := "Success!" + http.Redirect(w, r, fmt.Sprintf("/battle/%d?res=%s", battleid, msg), http.StatusSeeOther) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} diff --git a/src/bit.go b/src/bit.go new file mode 100644 index 0000000..33c693a --- /dev/null +++ b/src/bit.go @@ -0,0 +1,42 @@ +package main + +// struct element names can't start with nums... +type Bit struct { + ID int + Name string + Enabled bool +} + +////////////////////////////////////////////////////////////////////////////// +// GENERAL PURPOSE + +func BitGetAll() ([]Bit, error) { + return globalState.GetAllBits() +} + +////////////////////////////////////////////////////////////////////////////// +// DATABASE + +func (s *State) GetAllBits() ([]Bit, error) { + rows, err := s.db.Query("SELECT id, name FROM bits") + defer rows.Close() + if err != nil { + return nil, err + } + + var bit []Bit + for rows.Next() { + var arch Bit + if err := rows.Scan(&arch.ID, &arch.Name); err != nil { + return bit, err + } + bit = append(bit, arch) + } + if err = rows.Err(); err != nil { + return bit, err + } + return bit, nil +} + +////////////////////////////////////////////////////////////////////////////// +// HTTP diff --git a/src/bot.go b/src/bot.go new file mode 100644 index 0000000..1a0d342 --- /dev/null +++ b/src/bot.go @@ -0,0 +1,842 @@ +package main + +import ( + "fmt" + "html/template" + "log" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gorilla/mux" + "github.com/radareorg/r2pipe-go" +) + +type Bot struct { + ID int + Name string + Source string + Users []User + + Archs []Arch + Bits []Bit +} + +////////////////////////////////////////////////////////////////////////////// +// GENERAL PURPOSE + +func BotCreate(name string, source string) (int, error) { + return globalState.InsertBot(Bot{Name: name, Source: source}) +} + +func BotUpdate(botid int, name string, source string) error { + return globalState.UpdateBot(Bot{ID: botid, Name: name, Source: source}) +} + +func BotGetById(id int) (Bot, error) { + return globalState.GetBotById(id) +} + +func BotGetAll() ([]Bot, error) { + return globalState.GetAllBot() +} + +func BotLinkArchIDs(botid int, archIDs []int) error { + return globalState.LinkArchIDsToBot(botid, archIDs) +} + +func BotLinkBitIDs(botid int, bitIDs []int) error { + return globalState.LinkBitIDsToBot(botid, bitIDs) +} + +////////////////////////////////////////////////////////////////////////////// +// DATABASE + +func (s *State) InsertBot(bot Bot) (int, error) { + res, err := s.db.Exec("INSERT INTO bots VALUES(NULL,?,?,?);", time.Now(), bot.Name, bot.Source) + if err != nil { + return 0, err + } + + var id int64 + if id, err = res.LastInsertId(); err != nil { + return 0, err + } + return int(id), nil +} + +func (s *State) UpdateBot(bot Bot) error { + _, err := s.db.Exec("UPDATE bots SET name=?, source=? WHERE id=?", bot.Name, bot.Source, bot.ID) + if err != nil { + return err + } + return nil +} + +func (s *State) GetBotById(id int) (Bot, error) { + var botid int + var botname string + var botsource string + + var ownerids string + var ownernames string + + var archids string + var archnames string + + var bitids string + var bitnames string + + err := s.db.QueryRow(` + SELECT + bo.id, bo.name, bo.source, + COALESCE(group_concat(ub.user_id), ""), + COALESCE(group_concat(us.name), ""), + COALESCE(group_concat(ab.arch_id), ""), + COALESCE(group_concat(ar.name), ""), + COALESCE(group_concat(bb.bit_id), ""), + COALESCE(group_concat(bi.name), "") + FROM bots bo + + LEFT JOIN user_bot_rel ub ON ub.bot_id = bo.id + LEFT JOIN users us ON us.id = ub.user_id + + LEFT JOIN arch_bot_rel ab ON ab.bot_id = bo.id + LEFT JOIN archs ar ON ar.id = ab.arch_id + + LEFT JOIN bit_bot_rel bb ON bb.bot_id = bo.id + LEFT JOIN bits bi ON bi.id = bb.bit_id + + WHERE bo.id=? + GROUP BY bo.id; + `, id).Scan(&botid, &botname, &botsource, &ownerids, &ownernames, &archids, &archnames, &bitids, &bitnames) + if err != nil { + log.Println(err) + return Bot{}, err + } + + // log.Println("botid: ", botid) + // log.Println("botname: ", botname) + // log.Println("botsource: ", botsource) + // log.Println("ownerids: ", ownerids) + // log.Println("ownernames: ", ownernames) + // log.Println("archid: ", archids) + // log.Println("archname: ", archnames) + // log.Println("bitid: ", bitids) + // log.Println("bitname: ", bitnames) + + ownerIDList := strings.Split(ownerids, ",") + ownerNameList := strings.Split(ownernames, ",") + + var users []User + for i, _ := range ownerIDList { + id, err := strconv.Atoi(ownerIDList[i]) + if err != nil { + log.Println("ERR1: ", err) + return Bot{}, err + } + users = append(users, User{ID: id, Name: ownerNameList[i], PasswordHash: nil}) + } + + // assemble the archs + archIDList := strings.Split(archids, ",") + archNameList := strings.Split(archnames, ",") + + var archs []Arch + if archIDList[0] != "" { + for i, _ := range archIDList { + id, err := strconv.Atoi(archIDList[i]) + if err != nil { + log.Println("Err handling archs: ", err) + return Bot{}, err + } + archs = append(archs, Arch{id, archNameList[i], true}) + } + } else { + archs = []Arch{} + } + + // assemble the bits + bitIDList := strings.Split(bitids, ",") + bitNameList := strings.Split(bitnames, ",") + + var bits []Bit + if bitIDList[0] != "" { + for i, _ := range bitIDList { + id, err := strconv.Atoi(bitIDList[i]) + if err != nil { + log.Println("Err handling bits: ", err) + return Bot{}, err + } + bits = append(bits, Bit{id, bitNameList[i], true}) + } + } else { + bits = []Bit{} + } + + switch { + case err != nil: + log.Println("ERR4: ", err) + return Bot{}, err + default: + return Bot{botid, botname, botsource, users, archs, bits}, nil + } +} + +func (s *State) UpdateBotSource(name string, source string) error { + _, err := s.db.Exec("UPDATE bots SET source=? WHERE name=?", source, name) + if err != nil { + return err + } else { + return nil + } +} + +func (s *State) UpdateBotName(orig_name string, new_name string) error { + _, err := s.db.Exec("UPDATE bots SET name=? WHERE name=?", new_name, orig_name) + if err != nil { + return err + } else { + return nil + } +} + +// Returns the users belonging to the given bot +func (s *State) GetBotUsers(botid int) ([]User, error) { + rows, err := s.db.Query("SELECT id, name FROM users u LEFT JOIN user_bot_rel ub ON ub.user_id = u.id WHERE ub.bot_id=?", botid) + defer rows.Close() + if err != nil { + return nil, err + } + + var users []User + for rows.Next() { + var user User + if err := rows.Scan(&user.ID, &user.Name); err != nil { + return users, err + } + users = append(users, user) + } + if err = rows.Err(); err != nil { + return users, err + } + return users, nil +} + +// Returns the users belonging to the given bot +func (s *State) GetAllBot() ([]Bot, error) { + rows, err := s.db.Query("SELECT id, name FROM bots;") + defer rows.Close() + if err != nil { + return nil, err + } + + var bots []Bot + for rows.Next() { + var bot Bot + if err := rows.Scan(&bot.ID, &bot.Name); err != nil { + return bots, err + } + bots = append(bots, bot) + } + if err = rows.Err(); err != nil { + return bots, err + } + return bots, nil +} + +// Returns the users belonging to the given bot +func (s *State) GetAllBotsWithUsers() ([]Bot, error) { + rows, err := s.db.Query(`SELECT + b.id, b.name, b.source, group_concat(ub.user_id), group_concat(u.name) + FROM bots b + LEFT JOIN user_bot_rel ub ON ub.bot_id = b.id + LEFT JOIN users u ON ub.user_id = u.id + GROUP BY b.id;`) + defer rows.Close() + if err != nil { + return nil, err + } + + var bots []Bot + for rows.Next() { + var bot Bot + var userIDListString string + var usernameListString string + + err := rows.Scan(&bot.ID, &bot.Name, &bot.Source, &userIDListString, &usernameListString) + if err != nil { + return nil, err + } + + userIDList := strings.Split(userIDListString, ",") + usernameList := strings.Split(usernameListString, ",") + + var users []User + for i, _ := range userIDList { + id, err := strconv.Atoi(userIDList[i]) + if err != nil { + return nil, err + } + users = append(users, User{ID: id, Name: usernameList[i], PasswordHash: nil}) + } + bot.Users = users + + bots = append(bots, bot) + } + if err = rows.Err(); err != nil { + return bots, err + } + return bots, nil +} + +func (s *State) LinkArchIDsToBot(botid int, archIDs []int) error { + // delete preexisting links + _, err := s.db.Exec("DELETE FROM arch_bot_rel WHERE bot_id=?;", botid) + if err != nil { + log.Println("Error deleting old arch bot link: ", err) + return err + } + + // yes, we're building this by hand, but as we only insert int's I'm just confident that whoever + // gets some sqli here just deserves it :D + query := "INSERT INTO arch_bot_rel (arch_id, bot_id) VALUES" + for idx, id := range archIDs { + query += fmt.Sprintf("(%d, %d)", id, botid) + if idx != len(archIDs)-1 { + query += ", " + } + } + query += ";" + log.Println(query) + + _, err = s.db.Exec(query) + if err != nil { + log.Println("LinkArchIDsToBot err: ", err) + return err + } else { + return nil + } +} + +func (s *State) LinkBitIDsToBot(botid int, bitIDs []int) error { + // delete preexisting links + _, err := s.db.Exec("DELETE FROM bit_bot_rel WHERE bot_id=?;", botid) + if err != nil { + log.Println("Error deleting old bit bot link: ", err) + return err + } + + // yes, we're building this by hand, but as we only insert int's I'm just confident that whoever + // gets some sqli here just deserves it :D + query := "INSERT INTO bit_bot_rel (bit_id, bot_id) VALUES" + for idx, id := range bitIDs { + query += fmt.Sprintf("(%d, %d)", id, botid) + if idx != len(bitIDs)-1 { + query += ", " + } + } + query += ";" + log.Println(query) + + _, err = s.db.Exec(query) + if err != nil { + log.Println("LinkBitIDsToBot err: ", err) + return err + } else { + return nil + } +} + +////////////////////////////////////////////////////////////////////////////// +// HTTP + +func botsHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"] + data["pagelink1"] = Link{"bot", "/bot"} + data["pagelink1options"] = []Link{ + {Name: "user", Target: "/user"}, + {Name: "battle", Target: "/battle"}, + } + data["pagelinknext"] = []Link{ + {Name: "new", Target: "/new"}, + } + + if username == nil { + http.Redirect(w, r, "/login", http.StatusMethodNotAllowed) + } + + user, err := UserGetUserFromUsername(username.(string)) + if err != nil { + data["err"] = "Could not fetch the user" + } else { + data["user"] = user + } + + bots, err := globalState.GetAllBotsWithUsers() + data["bots"] = bots + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "bots", data) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func botSingleHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + botid, err := strconv.Atoi(vars["id"]) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Invalid bot id")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{"bot", "/bot"} + data["pagelink1options"] = []Link{ + {Name: "user", Target: "/user"}, + {Name: "battle", Target: "/battle"}, + } + + // display errors passed via query parameters + log.Println("[d] Getting previous results") + queryres := r.URL.Query().Get("res") + if queryres != "" { + data["res"] = queryres + } + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + viewer, err := UserGetUserFromUsername(username) + if err != nil { + data["err"] = "Could not get the id four your username... Please contact an admin" + } + + bot, err := BotGetById(int(botid)) + data["bot"] = bot + data["user"] = viewer + + // open radare without input for building the bot + r2p1, err := r2pipe.NewPipe("--") + if err != nil { + panic(err) + } + defer r2p1.Close() + + src := strings.ReplaceAll(bot.Source, "\r\n", "; ") + radareCommand := fmt.Sprintf("rasm2 -a %s -b %s \"%+v\"", bot.Archs[0].Name, bot.Bits[0].Name, src) + bytecode, err := r2cmd(r2p1, radareCommand) + if err != nil { + data["err"] = "Error assembling the bot" + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + return + } + data["bytecode_r2cmd"] = radareCommand + data["bytecode"] = bytecode + + radareCommand = fmt.Sprintf("rasm2 -a %s -b %s -D %+v", bot.Archs[0].Name, bot.Bits[0].Name, bytecode) + disasm, err := r2cmd(r2p1, radareCommand) + if err != nil { + data["err"] = "Error disassembling the bot" + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + return + } + data["err"] = "Could not get the id four your username... Please contact an admin" + + data["disasm_r2cmd"] = radareCommand + data["disasm"] = disasm + + // define the breadcrumbs + data["pagelink2"] = Link{bot.Name, fmt.Sprintf("/%d", bot.ID)} + + allBotNames, err := BotGetAll() + var opts []Link + for _, bot := range allBotNames { + opts = append(opts, Link{Name: bot.Name, Target: fmt.Sprintf("/%d", bot.ID)}) + } + data["pagelink2options"] = opts + + editable := false + for _, user := range bot.Users { + if user.ID == viewer.ID { + editable = true + } + } + if editable == true { + data["editable"] = true + } + + // get all architectures and set the enable flag on the ones that are enabled in the battle + archs, err := ArchGetAll() + if err != nil { + data["err"] = "Could not fetch the archs" + } else { + data["archs"] = archs + } + + for i, a := range archs { + for _, b := range bot.Archs { + if a.ID == b.ID { + archs[i].Enabled = true + } + } + } + + // get all bits and set the enable flag on the ones that are enabled in the battle + bits, err := BitGetAll() + if err != nil { + data["err"] = "Could not fetch the bits" + } else { + data["bits"] = bits + } + + for i, a := range bits { + for _, b := range bot.Bits { + if a.ID == b.ID { + bits[i].Enabled = true + } + } + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "botSingle", data) + + case "POST": + // checking if the user submitting the bot information is allowed to do so + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + // get the user submitting + log.Println("Getting the user submitting the change request...") + requesting_user, err := UserGetUserFromUsername(username) + if err != nil { + log.Println("err: ", err) + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + return + } + + // get the users the bot belongs + log.Println("Getting the user the bot belongs to...") + orig_bot, err := BotGetById(int(botid)) + if err != nil { + log.Println("err: ", err) + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + return + } + + // check if the user submitting the change request is within the users the bot belongs to + log.Println("Checking if edit is allowed...") + allowed_to_edit := false + for _, user := range orig_bot.Users { + if user.ID == requesting_user.ID { + allowed_to_edit = true + } + } + + if allowed_to_edit == false { + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + return + } + + // at this point, we're sure the user is allowed to edit the bot + r.ParseForm() + name := r.Form.Get("name") + source := r.Form.Get("source") + + var archIDs []int + var bitIDs []int + + for k, _ := range r.Form { + if strings.HasPrefix(k, "arch-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "arch-")) + if err != nil { + msg := "ERROR: Invalid arch id" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + archIDs = append(archIDs, id) + } + if strings.HasPrefix(k, "bit-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "bit-")) + if err != nil { + msg := "ERROR: Invalid bit id" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + bitIDs = append(bitIDs, id) + } + } + + if len(archIDs) == 0 { + msg := "ERROR: Please select an architecture" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + if len(archIDs) >= 2 { + msg := "ERROR: Please select ONE architecture" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + + if len(bitIDs) == 0 { + msg := "ERROR: Please select one of the bits" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + if len(bitIDs) >= 2 { + msg := "ERROR: Please select ONE of the bits" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + + // link archs to battle + err = BotLinkArchIDs(botid, archIDs) + if err != nil { + log.Println("Error linking the arch ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + + // link bits to battle + err = BotLinkBitIDs(botid, bitIDs) + if err != nil { + log.Println("Error linking the bit ids to the battle: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/bot/%d?res=%s", botid, msg), http.StatusSeeOther) + return + } + + if name != "" { + if source != "" { + log.Println("Updating bot...") + err := BotUpdate(botid, name, source) + if err != nil { + log.Println("err: ", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error inserting bot into db")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + } + + http.Redirect(w, r, fmt.Sprintf("/bot/%d", botid), http.StatusSeeOther) + + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func botNewHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + data["pagelink1"] = Link{Name: "bot", Target: "/bot"} + data["pagelink1options"] = []Link{ + {Name: "user", Target: "/user"}, + {Name: "battle", Target: "/battle"}, + } + data["pagelink2"] = Link{Name: "new", Target: "/new"} + data["pagelink2options"] = []Link{ + {Name: "list", Target: ""}, + } + + // display errors passed via query parameters + log.Println("[d] Getting previous results") + queryres := r.URL.Query().Get("res") + if queryres != "" { + data["res"] = queryres + } + + user, err := UserGetUserFromUsername(username) + if err != nil { + data["err"] = "Could not fetch the user" + } else { + data["user"] = user + } + + archs, err := ArchGetAll() + if err != nil { + data["err"] = "Could not fetch the archs" + } else { + data["archs"] = archs + } + + bits, err := BitGetAll() + if err != nil { + data["err"] = "Could not fetch the bits" + } else { + data["bits"] = bits + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "botNew", data) + + case "POST": + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + // parse the post parameters + r.ParseForm() + log.Println("---") + log.Println(r.Form) + log.Println("---") + + name := r.Form.Get("name") + source := r.Form.Get("source") + + if name == "" { + msg := "ERROR: Please provide a name" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + if source == "" { + msg := "ERROR: Please provide some source" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + var archIDs []int + var bitIDs []int + + for k, _ := range r.Form { + if strings.HasPrefix(k, "arch-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "arch-")) + if err != nil { + msg := "ERROR: Invalid arch id" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + archIDs = append(archIDs, id) + } + if strings.HasPrefix(k, "bit-") { + id, err := strconv.Atoi(strings.TrimPrefix(k, "bit-")) + if err != nil { + msg := "ERROR: Invalid bit id" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + bitIDs = append(bitIDs, id) + } + } + + if len(archIDs) == 0 { + msg := "ERROR: Please select an architecture" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + if len(archIDs) >= 2 { + msg := "ERROR: Please select ONE architecture" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + if len(bitIDs) == 0 { + msg := "ERROR: Please select one of the bits" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + if len(bitIDs) >= 2 { + msg := "ERROR: Please select ONE of the bits" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + botid, err := BotCreate(name, source) + if err != nil { + log.Println("Error creating the bot: ", err) + msg := "ERROR: Could not create bot" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + err = UserLinkBot(username, botid) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error adding the bot to the user")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if len(archIDs) == 0 { + msg := "ERROR: Please select an architecture" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + // link archs to battle + err = BotLinkArchIDs(botid, archIDs) + if err != nil { + log.Println("Error linking the arch ids to the bot: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + if len(bitIDs) == 0 { + msg := "ERROR: Please select an bits" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + // link bits to battle + err = BotLinkBitIDs(botid, bitIDs) + if err != nil { + log.Println("Error linking the bit ids to the bot: ", err) + msg := "ERROR: Could not create due to internal reasons" + http.Redirect(w, r, fmt.Sprintf("/bot/new?res=%s", msg), http.StatusSeeOther) + return + } + + http.Redirect(w, r, "/bot", http.StatusSeeOther) + return + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} diff --git a/src/db.go b/src/db.go new file mode 100644 index 0000000..e36cf21 --- /dev/null +++ b/src/db.go @@ -0,0 +1,127 @@ +package main + +import ( + "database/sql" + "log" +) + +const create string = ` +CREATE TABLE IF NOT EXISTS users ( + id INTEGER NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL, + name TEXT, + passwordHash TEXT +); +CREATE TABLE IF NOT EXISTS bots ( + id INTEGER NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL, + name TEXT, + source TEXT +); +CREATE TABLE IF NOT EXISTS battles ( + id INTEGER NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL, + name TEXT, + public BOOLEAN +); +CREATE TABLE IF NOT EXISTS archs ( + id INTEGER NOT NULL PRIMARY KEY, + name TEXT, + UNIQUE(name) +); +INSERT OR IGNORE INTO archs (name) VALUES + ("null"), ("6502"), ("6502.cs"), ("8051"), ("alpha"), ("amd29k"), + ("any.as"), ("any.vasm"), ("arm.nz"), ("arm"), ("avr"), ("bf"), ("bpf.mr"), + ("bpf"), ("chip8"), ("cr16"), ("cris"), ("dalvik"), ("dis"), ("ebc"), + ("evm"), ("fslsp"), ("gb"), ("h8300"), ("i4004"), ("i8080"), ("java"), + ("jdh8"), ("kvx"), ("lh5801"), ("lm32"), ("m680x"), ("m68k"), ("mcore"), + ("mcs96"), ("mips"), ("msp430"), ("nios2"), ("or1k"), ("pic"), ("ppc"), + ("propeller"), ("pyc"), ("riscv"), ("riscv.cs"), ("rsp"), ("s390"), + ("sh"), ("sh.cs"), ("snes"), ("sparc"), ("tms320"), ("tricore"), + ("tricore.cs"), ("v850"), ("vax"), ("wasm"), ("ws"), ("x86"), ("x86.nz"), + ("xap"), ("xcore"), ("arm.gnu"), ("lanai"), ("loongarch"), ("m68k.gnu"), + ("mips.gnu"), ("nds32"), ("pdp11"), ("ppc.gnu"), ("s390.gnu"), + ("sparc.gnu"), ("xtensa"), ("z80") +; + +/* + ("x86-64"), ("Alpha"), ("ARM"), ("AVR"), ("BPF"), ("MIPS"), ("PowerPC"), + ("SPARC"), ("RISC-V"), ("SH"), ("m68k"), ("S390"), ("XCore"), ("CR16"), + ("HPPA"), ("ARC"), ("Blackfin"), ("Z80"), ("H8/300"), ("V810"), ("PDP11"), + ("m680x"), ("V850"), ("CRIS"), ("XAP (CSR)"), ("PIC"), ("LM32"), ("8051"), + ("6502"), ("i4004"), ("i8080"), ("Propeller"), ("EVM"), ("OR1K Tricore"), + ("CHIP-8"), ("LH5801"), ("T8200"), ("GameBoy"), ("SNES"), ("SPC700"), + ("MSP430"), ("Xtensa"), ("xcore"), ("NIOS II"), ("Java"), ("Dalvik"), + ("Pickle"), ("WebAssembly"), ("MSIL"), ("EBC"), ("TMS320"), ("c54x"), ("c55x"), + ("c55+"), ("c64x"), ("Hexagon"), ("Brainfuck"), ("Malbolge"), + ("whitespace"), ("DCPU16"), ("LANAI"), ("lm32"), ("MCORE"), ("mcs96"), + ("RSP"), ("SuperH-4"), ("VAX"), ("KVX"), ("Am29000"), ("LOONGARCH"), + ("JDH8"), ("s390x"), ("STM8.") +*/ + +CREATE TABLE IF NOT EXISTS bits ( + id INTEGER NOT NULL PRIMARY KEY, + name TEXT, + UNIQUE(name) +); +INSERT OR IGNORE INTO bits (name) VALUES + ("8"), ("16"), ("32"), ("64") +; + +CREATE TABLE IF NOT EXISTS user_bot_rel ( + user_id INTEGER, + bot_id INTEGER, + PRIMARY KEY(user_id, bot_id) +); +CREATE TABLE IF NOT EXISTS arch_bot_rel ( + arch_id INTEGER, + bot_id INTEGER, + PRIMARY KEY(arch_id, bot_id) +); +CREATE TABLE IF NOT EXISTS bit_bot_rel ( + bit_id INTEGER, + bot_id INTEGER, + PRIMARY KEY(bit_id, bot_id) +); + +CREATE TABLE IF NOT EXISTS user_battle_rel ( + user_id INTEGER, + battle_id INTEGER, + PRIMARY KEY(user_id, battle_id) +); +CREATE TABLE IF NOT EXISTS bot_battle_rel ( + bot_id INTEGER, + battle_id INTEGER, + PRIMARY KEY(bot_id, battle_id) +); +CREATE TABLE IF NOT EXISTS arch_battle_rel ( + arch_id INTEGER, + battle_id INTEGER, + PRIMARY KEY(arch_id, battle_id) +); +CREATE TABLE IF NOT EXISTS bit_battle_rel ( + bit_id INTEGER, + battle_id INTEGER, + PRIMARY KEY(bit_id, battle_id) +); +` + +type State struct { + db *sql.DB // the database storing the "business data" + sessions *SqliteStore // the database storing sessions +} + +func NewState() (*State, error) { + db, err := sql.Open("sqlite3", database_file) + if err != nil { + log.Println("Error opening the db: ", err) + return nil, err + } + if _, err := db.Exec(create); err != nil { + log.Println("Error creating the tables: ", err) + return nil, err + } + return &State{ + db: db, + }, nil +} diff --git a/src/http.go b/src/http.go new file mode 100644 index 0000000..894542a --- /dev/null +++ b/src/http.go @@ -0,0 +1,55 @@ +package main + +import ( + "html/template" + "net/http" +) + +type Link struct { + Name string + Target string +} + +func indexHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = "" + data["pagelinknext"] = []Link{ + {Name: "user/", Target: "/user"}, + {Name: "bot/", Target: "/bot"}, + {Name: "battle/", Target: "/battle"}, + } + data["pagelinkauth"] = []Link{ + {Name: "login/", Target: "/login"}, + {Name: "register/", Target: "/register"}, + } + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"] + + if username != nil { + data["logged_in"] = true + + user, err := UserGetUserFromUsername(username.(string)) + if err != nil { + data["err"] = "Couln't get the user" + } + data["user"] = user + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + return + } + + // exec! + t.ExecuteTemplate(w, "index", data) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} diff --git a/src/log.go b/src/log.go new file mode 100644 index 0000000..5af719a --- /dev/null +++ b/src/log.go @@ -0,0 +1,34 @@ +package main + +import ( + "net/http" + "os" + + "github.com/gorilla/handlers" +) + +// Defines a middleware containing a logfile +// +// This is done to combine gorilla/handlers with gorilla/mux middlewares to +// just use r.Use(logger.Middleware) once instead of adding this to all +// handlers manually (Yes, I'm really missing macros in Go...) +type loggingMiddleware struct { + logFile *os.File +} + +func (l *loggingMiddleware) Middleware(next http.Handler) http.Handler { + return handlers.LoggingHandler(l.logFile, next) +} + +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"] + + if username == nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + } else { + next.ServeHTTP(w, r) + } + }) +} diff --git a/src/main.go b/src/main.go new file mode 100644 index 0000000..27c066c --- /dev/null +++ b/src/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "log" + "net/http" + "os" + + "github.com/gorilla/mux" +) + +const database_file string = "main.db" +const salt = "oogha3AiH7taimohreeH8Lexoonea5zi" + +var ( + globalState *State +) + +func main() { + + // log init + log.Println("[i] Setting up logging...") + logFile, err := os.OpenFile("server.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0664) + if err != nil { + log.Fatal("Error opening the server.log file: ", err) + } + logger := loggingMiddleware{logFile} + + // db init + log.Println("[i] Setting up Global State Struct...") + s, err := NewState() + if err != nil { + log.Fatal("Error creating the NewState(): ", err) + } + globalState = s + + // session init + log.Println("[i] Setting up Session Storage...") + store, err := NewSqliteStore("./sessions.db", "sessions", "/", 3600, []byte(os.Getenv("SESSION_KEY"))) + if err != nil { + panic(err) + } + globalState.sessions = store + + // HTTP init + log.Println("[i] Setting up HTTP Routes...") + r := mux.NewRouter() + r.Use(logger.Middleware) + + // unauthenticated endpoints + r.HandleFunc("/", indexHandler) + r.HandleFunc("/login", loginHandler) + r.HandleFunc("/register", registerHandler) + r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("./static")))) + + // endpoints with auth needed + auth_needed := r.PathPrefix("/").Subrouter() + auth_needed.Use(authMiddleware) + auth_needed.HandleFunc("/logout", logoutHandler) + + auth_needed.HandleFunc("/bot", botsHandler) + auth_needed.HandleFunc("/bot/new", botNewHandler) + auth_needed.HandleFunc("/bot/{id}", botSingleHandler) + + auth_needed.HandleFunc("/user", usersHandler) + auth_needed.HandleFunc("/user/{id}", userHandler) + auth_needed.HandleFunc("/user/{id}/profile", profileHandler) + + auth_needed.HandleFunc("/battle", battlesHandler) + auth_needed.HandleFunc("/battle/new", battleNewHandler) + auth_needed.HandleFunc("/battle/{id}", battleSingleHandler) + auth_needed.HandleFunc("/battle/{id}/submit", battleSubmitHandler) + + log.Println("[i] HTTP Server running on port :8080") + log.Fatal(http.ListenAndServe(":8080", r)) +} diff --git a/src/r2.go b/src/r2.go new file mode 100644 index 0000000..6ecd24b --- /dev/null +++ b/src/r2.go @@ -0,0 +1,24 @@ +package main + +import ( + "log" + + "github.com/radareorg/r2pipe-go" +) + +func r2cmd(r2p *r2pipe.Pipe, input string) (string, error) { + + log.Println("---") + log.Printf("> %s\n", input) + log.Println("---") + + // send a command + buf1, err := r2p.Cmd(input) + if err != nil { + log.Println(err) + return "", err + } + + // return the result of the command as a string + return buf1, nil +} diff --git a/src/sql.go b/src/sql.go new file mode 100644 index 0000000..858878b --- /dev/null +++ b/src/sql.go @@ -0,0 +1,31 @@ +package main + +import ( + "database/sql" + "database/sql/driver" + "modernc.org/sqlite" +) + +type sqlite3Driver struct { + *sqlite.Driver +} + +type sqlite3DriverConn interface { + Exec(string, []driver.Value) (driver.Result, error) +} + +func (d sqlite3Driver) Open(name string) (conn driver.Conn, err error) { + conn, err = d.Driver.Open(name) + if err != nil { + return + } + _, err = conn.(sqlite3DriverConn).Exec("PRAGMA foreign_keys = ON;", nil) + if err != nil { + _ = conn.Close() + } + return +} + +func init() { + sql.Register("sqlite3", sqlite3Driver{Driver: &sqlite.Driver{}}) +} diff --git a/src/sqlitestore.go b/src/sqlitestore.go new file mode 100644 index 0000000..6f59d15 --- /dev/null +++ b/src/sqlitestore.go @@ -0,0 +1,284 @@ +/* + Gorilla Sessions backend for SQLite. + +Copyright (c) 2013 Contributors. See the list of contributors in the CONTRIBUTORS file for details. + +This software is licensed under a MIT style license available in the LICENSE file. +*/ +package main + +import ( + "database/sql" + "encoding/gob" + "errors" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + _ "modernc.org/sqlite" +) + +type SqliteStore struct { + db DB + stmtInsert *sql.Stmt + stmtDelete *sql.Stmt + stmtUpdate *sql.Stmt + stmtSelect *sql.Stmt + + Codecs []securecookie.Codec + Options *sessions.Options + table string +} + +type sessionRow struct { + id string + data string + createdOn time.Time + modifiedOn time.Time + expiresOn time.Time +} + +type DB interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Close() error +} + +func init() { + gob.Register(time.Time{}) +} + +func NewSqliteStore(endpoint string, tableName string, path string, maxAge int, keyPairs ...[]byte) (*SqliteStore, error) { + db, err := sql.Open("sqlite3", endpoint) + if err != nil { + return nil, err + } + + return NewSqliteStoreFromConnection(db, tableName, path, maxAge, keyPairs...) +} + +func NewSqliteStoreFromConnection(db DB, tableName string, path string, maxAge int, keyPairs ...[]byte) (*SqliteStore, error) { + // Make sure table name is enclosed. + tableName = "`" + strings.Trim(tableName, "`") + "`" + + cTableQ := "CREATE TABLE IF NOT EXISTS " + + tableName + " (id INTEGER PRIMARY KEY, " + + "session_data LONGBLOB, " + + "created_on TIMESTAMP DEFAULT 0, " + + "modified_on TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + + "expires_on TIMESTAMP DEFAULT 0);" + if _, err := db.Exec(cTableQ); err != nil { + return nil, err + } + + insQ := "INSERT INTO " + tableName + + "(id, session_data, created_on, modified_on, expires_on) VALUES (NULL, ?, ?, ?, ?)" + stmtInsert, stmtErr := db.Prepare(insQ) + if stmtErr != nil { + return nil, stmtErr + } + + delQ := "DELETE FROM " + tableName + " WHERE id = ?" + stmtDelete, stmtErr := db.Prepare(delQ) + if stmtErr != nil { + return nil, stmtErr + } + + updQ := "UPDATE " + tableName + " SET session_data = ?, created_on = ?, expires_on = ? " + + "WHERE id = ?" + stmtUpdate, stmtErr := db.Prepare(updQ) + if stmtErr != nil { + return nil, stmtErr + } + + selQ := "SELECT id, session_data, created_on, modified_on, expires_on from " + + tableName + " WHERE id = ?" + stmtSelect, stmtErr := db.Prepare(selQ) + if stmtErr != nil { + return nil, stmtErr + } + + return &SqliteStore{ + db: db, + stmtInsert: stmtInsert, + stmtDelete: stmtDelete, + stmtUpdate: stmtUpdate, + stmtSelect: stmtSelect, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{ + Path: path, + MaxAge: maxAge, + }, + table: tableName, + }, nil +} + +func (m *SqliteStore) Close() { + m.stmtSelect.Close() + m.stmtUpdate.Close() + m.stmtDelete.Close() + m.stmtInsert.Close() + m.db.Close() +} + +func (m *SqliteStore) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(m, name) +} + +func (m *SqliteStore) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(m, name) + session.Options = &sessions.Options{ + Path: m.Options.Path, + MaxAge: m.Options.MaxAge, + } + session.IsNew = true + var err error + if cook, errCookie := r.Cookie(name); errCookie == nil { + err = securecookie.DecodeMulti(name, cook.Value, &session.ID, m.Codecs...) + if err == nil { + err = m.load(session) + if err == nil { + session.IsNew = false + } else { + err = nil + } + } + } + return session, err +} + +func (m *SqliteStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + var err error + if session.ID == "" { + if err = m.insert(session); err != nil { + return err + } + } else if err = m.save(session); err != nil { + return err + } + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, m.Codecs...) + if err != nil { + return err + } + http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) + return nil +} + +func (m *SqliteStore) insert(session *sessions.Session) error { + var createdOn time.Time + var modifiedOn time.Time + var expiresOn time.Time + crOn := session.Values["created_on"] + if crOn == nil { + createdOn = time.Now() + } else { + createdOn = crOn.(time.Time) + } + modifiedOn = createdOn + exOn := session.Values["expires_on"] + if exOn == nil { + expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge)) + } else { + expiresOn = exOn.(time.Time) + } + delete(session.Values, "created_on") + delete(session.Values, "expires_on") + delete(session.Values, "modified_on") + + encoded, encErr := securecookie.EncodeMulti(session.Name(), session.Values, m.Codecs...) + if encErr != nil { + return encErr + } + res, insErr := m.stmtInsert.Exec(encoded, createdOn, modifiedOn, expiresOn) + if insErr != nil { + return insErr + } + lastInserted, lInsErr := res.LastInsertId() + if lInsErr != nil { + return lInsErr + } + session.ID = fmt.Sprintf("%d", lastInserted) + return nil +} + +func (m *SqliteStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + + // Set cookie to expire. + options := *session.Options + options.MaxAge = -1 + http.SetCookie(w, sessions.NewCookie(session.Name(), "", &options)) + // Clear session values. + for k := range session.Values { + delete(session.Values, k) + } + + _, delErr := m.stmtDelete.Exec(session.ID) + if delErr != nil { + return delErr + } + return nil +} + +func (m *SqliteStore) save(session *sessions.Session) error { + if session.IsNew == true { + return m.insert(session) + } + var createdOn time.Time + var expiresOn time.Time + crOn := session.Values["created_on"] + if crOn == nil { + createdOn = time.Now() + } else { + createdOn = crOn.(time.Time) + } + + exOn := session.Values["expires_on"] + if exOn == nil { + expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge)) + log.Print("nil") + } else { + expiresOn = exOn.(time.Time) + if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 { + expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge)) + } + } + + delete(session.Values, "created_on") + delete(session.Values, "expires_on") + delete(session.Values, "modified_on") + encoded, encErr := securecookie.EncodeMulti(session.Name(), session.Values, m.Codecs...) + if encErr != nil { + return encErr + } + _, updErr := m.stmtUpdate.Exec(encoded, createdOn, expiresOn, session.ID) + if updErr != nil { + return updErr + } + return nil +} + +func (m *SqliteStore) load(session *sessions.Session) error { + row := m.stmtSelect.QueryRow(session.ID) + sess := sessionRow{} + scanErr := row.Scan(&sess.id, &sess.data, &sess.createdOn, &sess.modifiedOn, &sess.expiresOn) + if scanErr != nil { + return scanErr + } + if sess.expiresOn.Sub(time.Now()) < 0 { + log.Printf("Session expired on %s, but it is %s now.", sess.expiresOn, time.Now()) + return errors.New("Session expired") + } + err := securecookie.DecodeMulti(session.Name(), sess.data, &session.Values, m.Codecs...) + if err != nil { + return err + } + session.Values["created_on"] = sess.createdOn + session.Values["modified_on"] = sess.modifiedOn + session.Values["expires_on"] = sess.expiresOn + return nil + +} diff --git a/src/user.go b/src/user.go new file mode 100644 index 0000000..cc77657 --- /dev/null +++ b/src/user.go @@ -0,0 +1,666 @@ +package main + +import ( + "fmt" + "html/template" + "log" + "net/http" + "strconv" + "time" + + "github.com/gorilla/mux" + "golang.org/x/crypto/argon2" +) + +type User struct { + ID int + Name string + PasswordHash []byte +} + +////////////////////////////////////////////////////////////////////////////// +// GENERAL PURPOSE + +func UserRegister(username string, passwordHash []byte) (int, error) { + id, err := globalState.InsertUser(User{Name: username, PasswordHash: passwordHash}) + if err != nil { + return 0, err + } else { + return id, nil + } +} + +// UserCheckPasswordHash returns a boolean that is true if the users password +// is correct and false if the users password is false and thus doesn't match +// the one stored in the database +func UserCheckPasswordHash(username string, passwordHash []byte) bool { + return globalState.CheckUserHash(username, passwordHash) +} + +// UserUpdatePasswordHash does exactly that +func UserUpdatePasswordHash(orig_username string, passwordHash []byte) error { + return globalState.UpdateUserPasswordHash(orig_username, passwordHash) +} + +func UserUpdateUsername(id int, new_username string) error { + return globalState.UpdateUserUsername(id, new_username) +} + +func UserLinkBot(username string, botid int) error { + return globalState.LinkUserBot(username, botid) +} + +func UserGetBotsUsingUsername(username string) ([]Bot, error) { + return globalState.GetUserBotsUsername(username) +} + +func UserGetBotsUsingUserID(userid int) ([]Bot, error) { + return globalState.GetUserBotsId(userid) +} + +func UserGetUserFromID(userid int) (User, error) { + return globalState.GetUserFromId(userid) +} + +func UserGetUserFromUsername(username string) (User, error) { + return globalState.GetUserFromUsername(username) +} + +func UserGetAll() ([]User, error) { + return globalState.GetAllUsers() +} + +////////////////////////////////////////////////////////////////////////////// +// DATABASE + +func (s *State) InsertUser(user User) (int, error) { + res, err := s.db.Exec("INSERT INTO users VALUES(NULL,?,?,?);", time.Now(), user.Name, user.PasswordHash) + if err != nil { + return 0, err + } + + var id int64 + if id, err = res.LastInsertId(); err != nil { + return 0, err + } + return int(id), nil +} + +// returns true if the password matches +func (s *State) CheckUserHash(username string, passwordHash []byte) bool { + var created time.Time + err := s.db.QueryRow("SELECT created_at FROM users WHERE name=? AND passwordHash=?", username, passwordHash).Scan(&created) + switch { + case err != nil: + return false + default: + return true + } +} + +func (s *State) UpdateUserPasswordHash(username string, passwordHash []byte) error { + _, err := s.db.Exec("UPDATE users SET passwordHash=? WHERE name=?", passwordHash, username) + if err != nil { + return err + } else { + return nil + } +} + +func (s *State) UpdateUserUsername(id int, new_username string) error { + _, err := s.db.Exec("UPDATE users SET name=? WHERE id=?", new_username, id) + if err != nil { + return err + } else { + return nil + } +} + +// Links the given bot to the given user in the user_bot_rel table +func (s *State) LinkUserBot(username string, botid int) error { + _, err := s.db.Exec("INSERT INTO user_bot_rel VALUES ((SELECT id FROM users WHERE name=?), ?)", username, botid) + if err != nil { + return err + } else { + return nil + } +} + +// Links the given bot to the given user in the user_bot_rel table +func (s *State) GetUserFromId(id int) (User, error) { + var user_id int + var username string + err := s.db.QueryRow("SELECT id, name FROM users WHERE id=?", id).Scan(&user_id, &username) + if err != nil { + return User{}, err + } else { + return User{user_id, username, nil}, nil + } +} + +func (s *State) GetUserFromUsername(username string) (User, error) { + var id int + var name string + err := s.db.QueryRow("SELECT id, name FROM users WHERE name=?", username).Scan(&id, &name) + if err != nil { + return User{}, err + } else { + return User{id, name, nil}, nil + } +} + +// Returns the bots belonging to the given user +func (s *State) GetUserBotsUsername(username string) ([]Bot, error) { + rows, err := s.db.Query("SELECT id, name, source FROM bots b LEFT JOIN user_bot_rel ub ON ub.bot_id = b.id WHERE ub.user_id=(SELECT id FROM users WHERE name=?)", username) + defer rows.Close() + if err != nil { + return nil, err + } + + var bots []Bot + for rows.Next() { + var bot Bot + if err := rows.Scan(&bot.ID, &bot.Name, &bot.Source); err != nil { + return bots, err + } + bots = append(bots, bot) + } + if err = rows.Err(); err != nil { + return bots, err + } + return bots, nil +} + +// Returns the bots belonging to the given user +func (s *State) GetUserBotsId(id int) ([]Bot, error) { + rows, err := s.db.Query("SELECT id, name, source FROM bots b LEFT JOIN user_bot_rel ub ON ub.bot_id = b.id WHERE ub.user_id=?", id) + defer rows.Close() + if err != nil { + return nil, err + } + + var bots []Bot + for rows.Next() { + var bot Bot + if err := rows.Scan(&bot.ID, &bot.Name, &bot.Source); err != nil { + return bots, err + } + bots = append(bots, bot) + } + if err = rows.Err(); err != nil { + return bots, err + } + return bots, nil +} + +// Returns the bots belonging to the given user +func (s *State) GetAllUsers() ([]User, error) { + rows, err := s.db.Query("SELECT id, name FROM users") + defer rows.Close() + if err != nil { + return nil, err + } + + var users []User + for rows.Next() { + var user User + if err := rows.Scan(&user.ID, &user.Name); err != nil { + return users, err + } + users = append(users, user) + } + if err = rows.Err(); err != nil { + return users, err + } + return users, nil +} + +////////////////////////////////////////////////////////////////////////////// +// HTTP + +func loginHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + log.Println("GET /login") + // define data + log.Println("[d] Defining breadcrumbs") + data := map[string]interface{}{} + data["pagelink1"] = Link{"login", "/login"} + data["pagelink1options"] = []Link{ + {Name: "register", Target: "/register"}, + } + data["pagelinkauth"] = []Link{ + {Name: "register/", Target: "/register"}, + } + + // session foo + log.Println("[d] Getting session") + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"] + + // get the user + if username != nil { + log.Printf("[d] Getting the user %s\n", username.(string)) + user, err := UserGetUserFromUsername(username.(string)) + if user.Name == "" { + log.Println("no user found") + } else if err != nil { + log.Println(err) + msg := "Error: could not get the user for given username" + http.Redirect(w, r, fmt.Sprintf("/login?res=%s", msg), http.StatusSeeOther) + return + } else { + data["user"] = user + } + } + + // display errors passed via query parameters + log.Println("[d] Getting previous results") + queryres := r.URL.Query().Get("res") + if queryres != "" { + data["res"] = queryres + } + + // get the template + log.Println("[d] Getting the template") + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + log.Println("Error parsing the login template: ", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + log.Println("[d] Executing the template") + t.ExecuteTemplate(w, "login", data) + + case "POST": + // parse the post parameters + r.ParseForm() + username := r.Form.Get("username") + password := r.Form.Get("password") + + // if we've got a password, hash it and compare it with the stored one + if password != "" { + passwordHash := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 32) + + // check if it's valid + valid := UserCheckPasswordHash(username, passwordHash) + if valid { + + // if it's valid, we set a session for the user + session, _ := globalState.sessions.Get(r, "session") + session.Values["username"] = username + err := session.Save(r, w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) + return + + } else { + // invalid password + http.Redirect(w, r, "/login?err=Invalid+Password", http.StatusSeeOther) + return + } + } else { + // empty password + http.Redirect(w, r, "/login?err=Empty+Password", http.StatusSeeOther) + return + } + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func registerHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + + // get the session + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"] + data["pagelink1"] = Link{"register", "/register"} + data["pagelink1options"] = []Link{ + {Name: "login", Target: "/login"}, + } + data["pagelinkauth"] = []Link{ + {Name: "login/", Target: "/login"}, + } + + log.Println(username) + + if username != nil { + data["logged_in"] = true + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "register", data) + + case "POST": + // parse the post parameters + r.ParseForm() + username := r.Form.Get("username") + password1 := r.Form.Get("password1") + password2 := r.Form.Get("password2") + + if len(username) >= 64 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! Please enter less than 64 chars!")) + return + } + + if len(password1) >= 256 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! Don't overdo with the length please!")) + return + } + + if password1 != password2 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! The passwords you entered don't match!")) + return + } + + // if we've got a password, hash it and store it and create a User + if password1 != "" { + passwordHash := argon2.IDKey([]byte(password1), []byte(salt), 1, 64*1024, 4, 32) + + _, err := UserRegister(username, passwordHash) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - We had problems inserting you into the DB")) + return + } + + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func logoutHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "POST": + session, _ := globalState.sessions.Get(r, "session") + session.Options.MaxAge = -1 + err := session.Save(r, w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + http.Redirect(w, r, "/", http.StatusSeeOther) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func userHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id, err := strconv.Atoi(vars["id"]) + if err != nil { + http.Redirect(w, r, "/user", http.StatusSeeOther) + } + + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{"user", "/user"} + data["pagelink1options"] = []Link{ + {Name: "bot", Target: "/bot"}, + {Name: "battle", Target: "/battle"}, + } + + // session foo + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + // the the user making the request + user, err := UserGetUserFromUsername(username) + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } else { + data["user"] = user + } + + // get the target user using the provided id + targetUser, err := UserGetUserFromID(id) + if err != nil { + data["err"] = "Could not find that user" + } else { + data["targetUser"] = targetUser + } + + // define the breadcrumbs + data["pagelink2"] = Link{targetUser.Name, fmt.Sprintf("/%s", targetUser.Name)} + + allUserNames, err := UserGetAll() + var opts []Link + for _, user := range allUserNames { + opts = append(opts, Link{Name: user.Name, Target: fmt.Sprintf("/%d", user.ID)}) + } + data["pagelink2options"] = opts + + // get the bots for the given user + bots, err := UserGetBotsUsingUserID(id) + if err != nil { + http.Redirect(w, r, "/user", http.StatusSeeOther) + } else { + data["bots"] = bots + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "user", data) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func usersHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{Name: "user", Target: "/user"} + data["pagelink1options"] = []Link{ + {Name: "bot", Target: "/bot"}, + {Name: "battle", Target: "/battle"}, + } + + // sessions + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + // get the user + user, err := UserGetUserFromUsername(username) + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } else { + data["user"] = user + } + + // get all users + users, err := UserGetAll() + data["users"] = users + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "users", data) + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} + +func profileHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id, err := strconv.Atoi(vars["id"]) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + switch r.Method { + case "GET": + // define data + data := map[string]interface{}{} + data["pagelink1"] = Link{"user", "/user"} + data["pagelink1options"] = []Link{ + {Name: "bot", Target: "/bot"}, + {Name: "battle", Target: "/battle"}, + } + + session, _ := globalState.sessions.Get(r, "session") + username := session.Values["username"].(string) + + target_user, err := UserGetUserFromID(id) + if err != nil { + // w.WriteHeader(http.StatusUnauthorized) + // w.Write([]byte("500 - Error reading template file")) + // http.Error(w, err.Error(), http.StatusInternalServerError) + data["err"] = "Error getting with the given id" + } + + if username != target_user.Name { + // w.WriteHeader(http.StatusInternalServerError) + // w.Write([]byte("500 - Error reading template file")) + // http.Error(w, err.Error(), http.StatusInternalServerError) + data["err"] = "You aren't allowed to edit any user except yourself" + } + + editing_user, err := UserGetUserFromUsername(username) + if err != nil { + data["err"] = "Coulnd't get a user for that id" + } + + data["user"] = editing_user + data["target_user"] = target_user + + data["pagelink2"] = Link{target_user.Name, fmt.Sprintf("/%d", id)} + allUserNames, err := UserGetAll() + var opts []Link + for _, user := range allUserNames { + opts = append(opts, Link{Name: user.Name, Target: fmt.Sprintf("/%d", user.ID)}) + } + data["pagelink2options"] = opts + + data["pagelink3"] = Link{"profile", "/profile"} + + if username != "" { + data["username"] = username + } + + // get the template + t, err := template.ParseGlob("./templates/*.html") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Error reading template file")) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // exec! + t.ExecuteTemplate(w, "profile", data) + + case "POST": + session, _ := globalState.sessions.Get(r, "session") + orig_username := session.Values["username"].(string) + + // parse the post parameters + r.ParseForm() + new_username := r.Form.Get("username") + password1 := r.Form.Get("password1") + password2 := r.Form.Get("password2") + + if len(new_username) >= 64 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! Please enter less than 64 chars!")) + return + } + + if len(password1) >= 256 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! Don't overdo with the length please!")) + return + } + + if password1 != password2 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - Oi', Backend here! The passwords you entered don't match!")) + return + } + + // first update the password, as they might have also changed their + // username + if password1 != "" { + passwordHash := argon2.IDKey([]byte(password1), []byte(salt), 1, 64*1024, 4, 32) + + err := UserUpdatePasswordHash(orig_username, passwordHash) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - We had problems inserting your new pw into the DB")) + return + } + } + + if new_username != "" { + err := UserUpdateUsername(id, new_username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 - We had problems inserting your new uname into the DB")) + return + } + + // after changing the username, we also have to update the username + // stored in the session + session.Values["username"] = new_username + err = session.Save(r, w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + http.Redirect(w, r, fmt.Sprintf("/user/%d/profile", id), http.StatusSeeOther) + return + default: + http.Redirect(w, r, "/", http.StatusMethodNotAllowed) + } +} |