diff options
-rw-r--r-- | db_actions.go | 77 | ||||
-rw-r--r-- | db_actions_test.go | 149 |
2 files changed, 205 insertions, 21 deletions
diff --git a/db_actions.go b/db_actions.go index eec7a93..59cadef 100644 --- a/db_actions.go +++ b/db_actions.go @@ -38,7 +38,8 @@ const ( ) var ( - db *sql.DB + db *sql.DB + treeWidth float64 ) // connectToDB returns a pointer to an sql database writing to the database @@ -62,6 +63,10 @@ func dbConnect(connStr string) *sql.DB { // newTree creates a new tree with the given width func NewTree(database *sql.DB, width float64) { db = database + treeWidth = width + + log.Printf("Creating a new tree with a width of %f", width) + // get the current max root id query := fmt.Sprintf("SELECT COALESCE(max(root_id), 0) FROM nodes") var currentMaxRootID int64 @@ -71,7 +76,7 @@ func NewTree(database *sql.DB, width float64) { } // build the query creating a new node - query = fmt.Sprintf("INSERT INTO nodes (box_width, root_id, box_center, depth, isleaf) VALUES (%f, %d, '{0, 0}', 0, TRUE)", width, currentMaxRootID+1) + query = fmt.Sprintf("INSERT INTO nodes (box_width, root_id, box_center, depth, isleaf, timestep) VALUES (%f, %d, '{0, 0}', 0, TRUE, %d)", width, currentMaxRootID+1, currentMaxRootID+1) // execute the query rows, err := db.Query(query) @@ -85,17 +90,29 @@ func NewTree(database *sql.DB, width float64) { func InsertStar(database *sql.DB, star structs.Star2D, index int64) { db = database start := time.Now() + + log.Printf("Inserting the star %v into the tree with the index %d", star, index) + // insert the star into the stars table starID := insertIntoStars(star) // get the root node id - query := fmt.Sprintf("SELECT node_id FROM nodes WHERE root_id=%d", index) + query := fmt.Sprintf("select case when exists (select node_id from nodes where root_id=%d) then (select node_id from nodes where root_id=%d) else -1 end;", index, index) var id int64 err := db.QueryRow(query).Scan(&id) + + // if there are no rows in the result set, create a new tree if err != nil { log.Fatalf("[ E ] Get root node id query: %v\n\t\t\t query: %s\n", err, query) } + if id == -1 { + NewTree(db, 1000) + id = getRootNodeID(index) + } + + log.Printf("Node id of the root node %d: %d", id, index) + // insert the star into the tree (using it's ID) starting at the root insertIntoTree(starID, id) elapsedTime := time.Since(start) @@ -262,6 +279,8 @@ func subdivide(nodeID int64) { boxWidth := getBoxWidth(nodeID) boxCenter := getBoxCenter(nodeID) originalDepth := getNodeDepth(nodeID) + timestep := getTimestepNode(nodeID) + log.Printf("Subdividing %d, setting the timestep to %d", nodeID, timestep) // calculate the new positions newPosX := boxCenter[0] + (boxWidth / 2) @@ -271,15 +290,15 @@ func subdivide(nodeID int64) { newWidth := boxWidth / 2 // create new news with those positions - newNodeIDA := newNode(newPosX, newPosY, newWidth, originalDepth+1) - newNodeIDB := newNode(newPosX, newNegY, newWidth, originalDepth+1) - newNodeIDC := newNode(newNegX, newPosY, newWidth, originalDepth+1) - newNodeIDD := newNode(newNegX, newNegY, newWidth, originalDepth+1) + newNodeIDA := newNode(newPosX, newPosY, newWidth, originalDepth+1, timestep) + newNodeIDB := newNode(newPosX, newNegY, newWidth, originalDepth+1, timestep) + newNodeIDC := newNode(newNegX, newPosY, newWidth, originalDepth+1, timestep) + newNodeIDD := newNode(newNegX, newNegY, newWidth, originalDepth+1, timestep) // Update the subtrees of the parent node // build the query - query := fmt.Sprintf("UPDATE nodes SET subnode='{%d, %d, %d, %d}', isleaf=FALSE WHERE node_id=%d", newNodeIDA, newNodeIDB, newNodeIDC, newNodeIDD, nodeID) + query := fmt.Sprintf("UPDATE nodes SET subnode='{%d, %d, %d, %d}', isleaf=FALSE, timestep=%d WHERE node_id=%d", newNodeIDA, newNodeIDB, newNodeIDC, newNodeIDD, timestep, nodeID) // Execute the query rows, err := db.Query(query) @@ -302,6 +321,19 @@ func getBoxWidth(nodeID int64) float64 { return boxWidth } +// getTimestepNode gets the timestep of the current node +func getTimestepNode(nodeID int64) int64 { + var timestep int64 + + query := fmt.Sprintf("SELECT timestep FROM nodes WHERE node_id=%d", nodeID) + err := db.QueryRow(query).Scan(×tep) + if err != nil { + log.Fatalf("[ E ] getTimeStep query: %v\n\t\t\t query: %s\n", err, query) + } + + return timestep +} + // getBoxWidth gets the center of the box from the node width the given id func getBoxCenter(nodeID int64) []float64 { var boxCenterX, boxCenterY []uint8 @@ -325,10 +357,23 @@ func getBoxCenter(nodeID int64) []float64 { return boxCenterFloat } +// getMaxTimestep gets the maximal timestep from the nodes table +func getMaxTimestep() float64 { + var maxTimestep float64 + + query := fmt.Sprintf("SELECT max(timestep) FROM nodes") + err := db.QueryRow(query).Scan(&maxTimestep) + if err != nil { + log.Fatalf("[ E ] getMaxTimestep query: %v\n\t\t\t query: %s\n", err, query) + } + + return maxTimestep +} + // newNode Inserts a new node into the database with the given parameters -func newNode(x float64, y float64, width float64, depth int64) int64 { +func newNode(x float64, y float64, width float64, depth int64, timestep int64) int64 { // build the query creating a new node - query := fmt.Sprintf("INSERT INTO nodes (box_center, box_width, depth, isleaf) VALUES ('{%f, %f}', %f, %d, TRUE) RETURNING node_id", x, y, width, depth) + query := fmt.Sprintf("INSERT INTO nodes (box_center, box_width, depth, isleaf, timestep) VALUES ('{%f, %f}', %f, %d, TRUE, %d) RETURNING node_id", x, y, width, depth, timestep) var nodeID int64 @@ -936,14 +981,14 @@ func CalcAllForces(database *sql.DB, star structs.Star2D, galaxyIndex int64, the // calculate all the forces and add them to the list of all forces // this is done recursively // first of all, get the root id - log.Println("getting the root ID") + log.Println("[db_actions] Getting the root ID") rootID := getRootNodeID(galaxyIndex) - log.Println("done getting the root ID") + log.Println("[db_actions] Done getting the root ID") - log.Printf("Calculating the forces acting on the star %v", star) + log.Printf("[db_actions] Calculating the forces acting on the star %v", star) force := CalcAllForcesNode(star, rootID, theta) - log.Printf("Done calculating the forces acting on the star %v", star) - log.Printf("Force: %v", force) + log.Printf("[db_actions] Done calculating the forces acting on the star %v", star) + log.Printf("[db_actions] Force: %v", force) return force } @@ -952,7 +997,7 @@ func CalcAllForces(database *sql.DB, star structs.Star2D, galaxyIndex int64, the // TODO: implement the calcForce(star, centerOfMass) {...} function // TODO: implement the getSubtreeIDs(nodeID) []int64 {...} function func CalcAllForcesNode(star structs.Star2D, nodeID int64, theta float64) structs.Vec2 { - fmt.Println("-----------------------------------------------------------") + log.Println("---------------------------------------") log.Printf("NodeID: %d \t star: %v \t theta: %f \t nodeboxwidth: %f", nodeID, star, theta, getBoxWidth(nodeID)) var forceX float64 var forceY float64 diff --git a/db_actions_test.go b/db_actions_test.go index 7b96426..b4baec9 100644 --- a/db_actions_test.go +++ b/db_actions_test.go @@ -31,9 +31,10 @@ func TestCalcAllForces(t *testing.T) { db.SetMaxOpenConns(75) type args struct { - database *sql.DB - star structs.Star2D - theta float64 + database *sql.DB + star structs.Star2D + galaxyIndex int64 + theta float64 } tests := []struct { name string @@ -41,7 +42,7 @@ func TestCalcAllForces(t *testing.T) { want structs.Vec2 }{ { - name: "force acting on a single star", + name: "star in the top right quadrant", args: args{ database: db, star: structs.Star2D{ @@ -62,12 +63,150 @@ func TestCalcAllForces(t *testing.T) { Y: 0, }, }, + { + name: "star in the bottom left quadrant", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: -100, + Y: -100, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + theta: 0.5, + }, + want: structs.Vec2{ + X: 0, + Y: 0, + }, + }, + { + name: "star in the far top right quadrant", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: 490, + Y: 490, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + theta: 0.5, + }, + want: structs.Vec2{ + X: 0, + Y: 0, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := CalcAllForces(tt.args.database, tt.args.star, tt.args.theta); !reflect.DeepEqual(got, tt.want) { + if got := CalcAllForces(tt.args.database, tt.args.star, tt.args.galaxyIndex, tt.args.theta); !reflect.DeepEqual(got, tt.want) { t.Errorf("CalcAllForces() = %v, want %v", got, tt.want) } }) } } + +func TestInsertStar(t *testing.T) { + // define a database + db = ConnectToDB() + db.SetMaxOpenConns(75) + + type args struct { + database *sql.DB + star structs.Star2D + index int64 + } + tests := []struct { + name string + args args + }{ + { + name: "Insert a star into the database", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: 100, + Y: 100, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + index: 1, + }, + }, + { + name: "Insert a star into the database", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: 150, + Y: 150, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + index: 1, + }, + }, + { + name: "Insert a star into the database", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: 150, + Y: 150, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + index: 2, + }, + }, + { + name: "Insert a star into the database", + args: args{ + database: db, + star: structs.Star2D{ + C: structs.Vec2{ + X: 100, + Y: 100, + }, + V: structs.Vec2{ + X: 0, + Y: 0, + }, + M: 1000, + }, + index: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + InsertStar(tt.args.database, tt.args.star, tt.args.index) + }) + } +} |