about summary refs log tree commit diff
path: root/src/sqlitestore.go
diff options
context:
space:
mode:
authorEmile <git@emile.space>2024-08-16 19:50:26 +0200
committerEmile <git@emile.space>2024-08-16 19:50:26 +0200
commit1a57267a17c2fc17fb6e104846fabc3e363c326c (patch)
tree1e574e3a80622086dc3c81ff9cba65ef7049b1a9 /src/sqlitestore.go
initial commit
Diffstat (limited to 'src/sqlitestore.go')
-rw-r--r--src/sqlitestore.go284
1 files changed, 284 insertions, 0 deletions
diff --git a/src/sqlitestore.go b/src/sqlitestore.go
new file mode 100644
index 0000000..6f59d15
--- /dev/null
+++ b/src/sqlitestore.go
@@ -0,0 +1,284 @@
+/*
+	Gorilla Sessions backend for SQLite.
+
+Copyright (c) 2013 Contributors. See the list of contributors in the CONTRIBUTORS file for details.
+
+This software is licensed under a MIT style license available in the LICENSE file.
+*/
+package main
+
+import (
+	"database/sql"
+	"encoding/gob"
+	"errors"
+	"fmt"
+	"log"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/gorilla/securecookie"
+	"github.com/gorilla/sessions"
+	_ "modernc.org/sqlite"
+)
+
+type SqliteStore struct {
+	db         DB
+	stmtInsert *sql.Stmt
+	stmtDelete *sql.Stmt
+	stmtUpdate *sql.Stmt
+	stmtSelect *sql.Stmt
+
+	Codecs  []securecookie.Codec
+	Options *sessions.Options
+	table   string
+}
+
+type sessionRow struct {
+	id         string
+	data       string
+	createdOn  time.Time
+	modifiedOn time.Time
+	expiresOn  time.Time
+}
+
+type DB interface {
+	Exec(query string, args ...interface{}) (sql.Result, error)
+	Prepare(query string) (*sql.Stmt, error)
+	Close() error
+}
+
+func init() {
+	gob.Register(time.Time{})
+}
+
+func NewSqliteStore(endpoint string, tableName string, path string, maxAge int, keyPairs ...[]byte) (*SqliteStore, error) {
+	db, err := sql.Open("sqlite3", endpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewSqliteStoreFromConnection(db, tableName, path, maxAge, keyPairs...)
+}
+
+func NewSqliteStoreFromConnection(db DB, tableName string, path string, maxAge int, keyPairs ...[]byte) (*SqliteStore, error) {
+	// Make sure table name is enclosed.
+	tableName = "`" + strings.Trim(tableName, "`") + "`"
+
+	cTableQ := "CREATE TABLE IF NOT EXISTS " +
+		tableName + " (id INTEGER PRIMARY KEY, " +
+		"session_data LONGBLOB, " +
+		"created_on TIMESTAMP DEFAULT 0, " +
+		"modified_on TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " +
+		"expires_on TIMESTAMP DEFAULT 0);"
+	if _, err := db.Exec(cTableQ); err != nil {
+		return nil, err
+	}
+
+	insQ := "INSERT INTO " + tableName +
+		"(id, session_data, created_on, modified_on, expires_on) VALUES (NULL, ?, ?, ?, ?)"
+	stmtInsert, stmtErr := db.Prepare(insQ)
+	if stmtErr != nil {
+		return nil, stmtErr
+	}
+
+	delQ := "DELETE FROM " + tableName + " WHERE id = ?"
+	stmtDelete, stmtErr := db.Prepare(delQ)
+	if stmtErr != nil {
+		return nil, stmtErr
+	}
+
+	updQ := "UPDATE " + tableName + " SET session_data = ?, created_on = ?, expires_on = ? " +
+		"WHERE id = ?"
+	stmtUpdate, stmtErr := db.Prepare(updQ)
+	if stmtErr != nil {
+		return nil, stmtErr
+	}
+
+	selQ := "SELECT id, session_data, created_on, modified_on, expires_on from " +
+		tableName + " WHERE id = ?"
+	stmtSelect, stmtErr := db.Prepare(selQ)
+	if stmtErr != nil {
+		return nil, stmtErr
+	}
+
+	return &SqliteStore{
+		db:         db,
+		stmtInsert: stmtInsert,
+		stmtDelete: stmtDelete,
+		stmtUpdate: stmtUpdate,
+		stmtSelect: stmtSelect,
+		Codecs:     securecookie.CodecsFromPairs(keyPairs...),
+		Options: &sessions.Options{
+			Path:   path,
+			MaxAge: maxAge,
+		},
+		table: tableName,
+	}, nil
+}
+
+func (m *SqliteStore) Close() {
+	m.stmtSelect.Close()
+	m.stmtUpdate.Close()
+	m.stmtDelete.Close()
+	m.stmtInsert.Close()
+	m.db.Close()
+}
+
+func (m *SqliteStore) Get(r *http.Request, name string) (*sessions.Session, error) {
+	return sessions.GetRegistry(r).Get(m, name)
+}
+
+func (m *SqliteStore) New(r *http.Request, name string) (*sessions.Session, error) {
+	session := sessions.NewSession(m, name)
+	session.Options = &sessions.Options{
+		Path:   m.Options.Path,
+		MaxAge: m.Options.MaxAge,
+	}
+	session.IsNew = true
+	var err error
+	if cook, errCookie := r.Cookie(name); errCookie == nil {
+		err = securecookie.DecodeMulti(name, cook.Value, &session.ID, m.Codecs...)
+		if err == nil {
+			err = m.load(session)
+			if err == nil {
+				session.IsNew = false
+			} else {
+				err = nil
+			}
+		}
+	}
+	return session, err
+}
+
+func (m *SqliteStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
+	var err error
+	if session.ID == "" {
+		if err = m.insert(session); err != nil {
+			return err
+		}
+	} else if err = m.save(session); err != nil {
+		return err
+	}
+	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, m.Codecs...)
+	if err != nil {
+		return err
+	}
+	http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
+	return nil
+}
+
+func (m *SqliteStore) insert(session *sessions.Session) error {
+	var createdOn time.Time
+	var modifiedOn time.Time
+	var expiresOn time.Time
+	crOn := session.Values["created_on"]
+	if crOn == nil {
+		createdOn = time.Now()
+	} else {
+		createdOn = crOn.(time.Time)
+	}
+	modifiedOn = createdOn
+	exOn := session.Values["expires_on"]
+	if exOn == nil {
+		expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
+	} else {
+		expiresOn = exOn.(time.Time)
+	}
+	delete(session.Values, "created_on")
+	delete(session.Values, "expires_on")
+	delete(session.Values, "modified_on")
+
+	encoded, encErr := securecookie.EncodeMulti(session.Name(), session.Values, m.Codecs...)
+	if encErr != nil {
+		return encErr
+	}
+	res, insErr := m.stmtInsert.Exec(encoded, createdOn, modifiedOn, expiresOn)
+	if insErr != nil {
+		return insErr
+	}
+	lastInserted, lInsErr := res.LastInsertId()
+	if lInsErr != nil {
+		return lInsErr
+	}
+	session.ID = fmt.Sprintf("%d", lastInserted)
+	return nil
+}
+
+func (m *SqliteStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
+
+	// Set cookie to expire.
+	options := *session.Options
+	options.MaxAge = -1
+	http.SetCookie(w, sessions.NewCookie(session.Name(), "", &options))
+	// Clear session values.
+	for k := range session.Values {
+		delete(session.Values, k)
+	}
+
+	_, delErr := m.stmtDelete.Exec(session.ID)
+	if delErr != nil {
+		return delErr
+	}
+	return nil
+}
+
+func (m *SqliteStore) save(session *sessions.Session) error {
+	if session.IsNew == true {
+		return m.insert(session)
+	}
+	var createdOn time.Time
+	var expiresOn time.Time
+	crOn := session.Values["created_on"]
+	if crOn == nil {
+		createdOn = time.Now()
+	} else {
+		createdOn = crOn.(time.Time)
+	}
+
+	exOn := session.Values["expires_on"]
+	if exOn == nil {
+		expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
+		log.Print("nil")
+	} else {
+		expiresOn = exOn.(time.Time)
+		if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 {
+			expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
+		}
+	}
+
+	delete(session.Values, "created_on")
+	delete(session.Values, "expires_on")
+	delete(session.Values, "modified_on")
+	encoded, encErr := securecookie.EncodeMulti(session.Name(), session.Values, m.Codecs...)
+	if encErr != nil {
+		return encErr
+	}
+	_, updErr := m.stmtUpdate.Exec(encoded, createdOn, expiresOn, session.ID)
+	if updErr != nil {
+		return updErr
+	}
+	return nil
+}
+
+func (m *SqliteStore) load(session *sessions.Session) error {
+	row := m.stmtSelect.QueryRow(session.ID)
+	sess := sessionRow{}
+	scanErr := row.Scan(&sess.id, &sess.data, &sess.createdOn, &sess.modifiedOn, &sess.expiresOn)
+	if scanErr != nil {
+		return scanErr
+	}
+	if sess.expiresOn.Sub(time.Now()) < 0 {
+		log.Printf("Session expired on %s, but it is %s now.", sess.expiresOn, time.Now())
+		return errors.New("Session expired")
+	}
+	err := securecookie.DecodeMulti(session.Name(), sess.data, &session.Values, m.Codecs...)
+	if err != nil {
+		return err
+	}
+	session.Values["created_on"] = sess.createdOn
+	session.Values["modified_on"] = sess.modifiedOn
+	session.Values["expires_on"] = sess.expiresOn
+	return nil
+
+}