about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEmile <hanemile@protonmail.com>2019-02-15 01:01:26 +0100
committerEmile <hanemile@protonmail.com>2019-02-15 01:01:26 +0100
commitbe7534abdb7d99e3d0229be4e6189dbbd81ef64b (patch)
tree90ca029b7e09b58507edc7dfdf5cef9ebacae662
parentc69573126b2d9fdb907b94ecab754417b95b842a (diff)
Fixed inserting a star into a specified timestep
-rw-r--r--db_actions.go77
-rw-r--r--db_actions_test.go149
2 files changed, 205 insertions, 21 deletions
diff --git a/db_actions.go b/db_actions.go
index eec7a93..59cadef 100644
--- a/db_actions.go
+++ b/db_actions.go
@@ -38,7 +38,8 @@ const (
 )
 
 var (
-	db *sql.DB
+	db        *sql.DB
+	treeWidth float64
 )
 
 // connectToDB returns a pointer to an sql database writing to the database
@@ -62,6 +63,10 @@ func dbConnect(connStr string) *sql.DB {
 // newTree creates a new tree with the given width
 func NewTree(database *sql.DB, width float64) {
 	db = database
+	treeWidth = width
+
+	log.Printf("Creating a new tree with a width of %f", width)
+
 	// get the current max root id
 	query := fmt.Sprintf("SELECT COALESCE(max(root_id), 0) FROM nodes")
 	var currentMaxRootID int64
@@ -71,7 +76,7 @@ func NewTree(database *sql.DB, width float64) {
 	}
 
 	// build the query creating a new node
-	query = fmt.Sprintf("INSERT INTO nodes (box_width, root_id, box_center, depth, isleaf) VALUES (%f, %d, '{0, 0}', 0, TRUE)", width, currentMaxRootID+1)
+	query = fmt.Sprintf("INSERT INTO nodes (box_width, root_id, box_center, depth, isleaf, timestep) VALUES (%f, %d, '{0, 0}', 0, TRUE, %d)", width, currentMaxRootID+1, currentMaxRootID+1)
 
 	// execute the query
 	rows, err := db.Query(query)
@@ -85,17 +90,29 @@ func NewTree(database *sql.DB, width float64) {
 func InsertStar(database *sql.DB, star structs.Star2D, index int64) {
 	db = database
 	start := time.Now()
+
+	log.Printf("Inserting the star %v into the tree with the index %d", star, index)
+
 	// insert the star into the stars table
 	starID := insertIntoStars(star)
 
 	// get the root node id
-	query := fmt.Sprintf("SELECT node_id FROM nodes WHERE root_id=%d", index)
+	query := fmt.Sprintf("select case when exists (select node_id from nodes where root_id=%d) then (select node_id from nodes where root_id=%d) else -1 end;", index, index)
 	var id int64
 	err := db.QueryRow(query).Scan(&id)
+
+	// if there are no rows in the result set, create a new tree
 	if err != nil {
 		log.Fatalf("[ E ] Get root node id query: %v\n\t\t\t query: %s\n", err, query)
 	}
 
+	if id == -1 {
+		NewTree(db, 1000)
+		id = getRootNodeID(index)
+	}
+
+	log.Printf("Node id of the root node %d: %d", id, index)
+
 	// insert the star into the tree (using it's ID) starting at the root
 	insertIntoTree(starID, id)
 	elapsedTime := time.Since(start)
@@ -262,6 +279,8 @@ func subdivide(nodeID int64) {
 	boxWidth := getBoxWidth(nodeID)
 	boxCenter := getBoxCenter(nodeID)
 	originalDepth := getNodeDepth(nodeID)
+	timestep := getTimestepNode(nodeID)
+	log.Printf("Subdividing %d, setting the timestep to %d", nodeID, timestep)
 
 	// calculate the new positions
 	newPosX := boxCenter[0] + (boxWidth / 2)
@@ -271,15 +290,15 @@ func subdivide(nodeID int64) {
 	newWidth := boxWidth / 2
 
 	// create new news with those positions
-	newNodeIDA := newNode(newPosX, newPosY, newWidth, originalDepth+1)
-	newNodeIDB := newNode(newPosX, newNegY, newWidth, originalDepth+1)
-	newNodeIDC := newNode(newNegX, newPosY, newWidth, originalDepth+1)
-	newNodeIDD := newNode(newNegX, newNegY, newWidth, originalDepth+1)
+	newNodeIDA := newNode(newPosX, newPosY, newWidth, originalDepth+1, timestep)
+	newNodeIDB := newNode(newPosX, newNegY, newWidth, originalDepth+1, timestep)
+	newNodeIDC := newNode(newNegX, newPosY, newWidth, originalDepth+1, timestep)
+	newNodeIDD := newNode(newNegX, newNegY, newWidth, originalDepth+1, timestep)
 
 	// Update the subtrees of the parent node
 
 	// build the query
-	query := fmt.Sprintf("UPDATE nodes SET subnode='{%d, %d, %d, %d}', isleaf=FALSE WHERE node_id=%d", newNodeIDA, newNodeIDB, newNodeIDC, newNodeIDD, nodeID)
+	query := fmt.Sprintf("UPDATE nodes SET subnode='{%d, %d, %d, %d}', isleaf=FALSE, timestep=%d WHERE node_id=%d", newNodeIDA, newNodeIDB, newNodeIDC, newNodeIDD, timestep, nodeID)
 
 	// Execute the query
 	rows, err := db.Query(query)
@@ -302,6 +321,19 @@ func getBoxWidth(nodeID int64) float64 {
 	return boxWidth
 }
 
+// getTimestepNode gets the timestep of the current node
+func getTimestepNode(nodeID int64) int64 {
+	var timestep int64
+
+	query := fmt.Sprintf("SELECT timestep FROM nodes WHERE node_id=%d", nodeID)
+	err := db.QueryRow(query).Scan(&timestep)
+	if err != nil {
+		log.Fatalf("[ E ] getTimeStep query: %v\n\t\t\t query: %s\n", err, query)
+	}
+
+	return timestep
+}
+
 // getBoxWidth gets the center of the box from the node width the given id
 func getBoxCenter(nodeID int64) []float64 {
 	var boxCenterX, boxCenterY []uint8
@@ -325,10 +357,23 @@ func getBoxCenter(nodeID int64) []float64 {
 	return boxCenterFloat
 }
 
+// getMaxTimestep gets the maximal timestep from the nodes table
+func getMaxTimestep() float64 {
+	var maxTimestep float64
+
+	query := fmt.Sprintf("SELECT max(timestep) FROM nodes")
+	err := db.QueryRow(query).Scan(&maxTimestep)
+	if err != nil {
+		log.Fatalf("[ E ] getMaxTimestep query: %v\n\t\t\t query: %s\n", err, query)
+	}
+
+	return maxTimestep
+}
+
 // newNode Inserts a new node into the database with the given parameters
-func newNode(x float64, y float64, width float64, depth int64) int64 {
+func newNode(x float64, y float64, width float64, depth int64, timestep int64) int64 {
 	// build the query creating a new node
-	query := fmt.Sprintf("INSERT INTO nodes (box_center, box_width, depth, isleaf) VALUES ('{%f, %f}', %f, %d, TRUE) RETURNING node_id", x, y, width, depth)
+	query := fmt.Sprintf("INSERT INTO nodes (box_center, box_width, depth, isleaf, timestep) VALUES ('{%f, %f}', %f, %d, TRUE, %d) RETURNING node_id", x, y, width, depth, timestep)
 
 	var nodeID int64
 
@@ -936,14 +981,14 @@ func CalcAllForces(database *sql.DB, star structs.Star2D, galaxyIndex int64, the
 	// calculate all the forces and add them to the list of all forces
 	// this is done recursively
 	// first of all, get the root id
-	log.Println("getting the root ID")
+	log.Println("[db_actions] Getting the root ID")
 	rootID := getRootNodeID(galaxyIndex)
-	log.Println("done getting the root ID")
+	log.Println("[db_actions] Done getting the root ID")
 
-	log.Printf("Calculating the forces acting on the star %v", star)
+	log.Printf("[db_actions] Calculating the forces acting on the star %v", star)
 	force := CalcAllForcesNode(star, rootID, theta)
-	log.Printf("Done calculating the forces acting on the star %v", star)
-	log.Printf("Force: %v", force)
+	log.Printf("[db_actions] Done calculating the forces acting on the star %v", star)
+	log.Printf("[db_actions] Force: %v", force)
 
 	return force
 }
@@ -952,7 +997,7 @@ func CalcAllForces(database *sql.DB, star structs.Star2D, galaxyIndex int64, the
 // TODO: implement the calcForce(star, centerOfMass) {...} function
 // TODO: implement the getSubtreeIDs(nodeID) []int64 {...} function
 func CalcAllForcesNode(star structs.Star2D, nodeID int64, theta float64) structs.Vec2 {
-	fmt.Println("-----------------------------------------------------------")
+	log.Println("---------------------------------------")
 	log.Printf("NodeID: %d \t star: %v \t theta: %f \t nodeboxwidth: %f", nodeID, star, theta, getBoxWidth(nodeID))
 	var forceX float64
 	var forceY float64
diff --git a/db_actions_test.go b/db_actions_test.go
index 7b96426..b4baec9 100644
--- a/db_actions_test.go
+++ b/db_actions_test.go
@@ -31,9 +31,10 @@ func TestCalcAllForces(t *testing.T) {
 	db.SetMaxOpenConns(75)
 
 	type args struct {
-		database *sql.DB
-		star     structs.Star2D
-		theta    float64
+		database    *sql.DB
+		star        structs.Star2D
+		galaxyIndex int64
+		theta       float64
 	}
 	tests := []struct {
 		name string
@@ -41,7 +42,7 @@ func TestCalcAllForces(t *testing.T) {
 		want structs.Vec2
 	}{
 		{
-			name: "force acting on a single star",
+			name: "star in the top right quadrant",
 			args: args{
 				database: db,
 				star: structs.Star2D{
@@ -62,12 +63,150 @@ func TestCalcAllForces(t *testing.T) {
 				Y: 0,
 			},
 		},
+		{
+			name: "star in the bottom left quadrant",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: -100,
+						Y: -100,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				theta: 0.5,
+			},
+			want: structs.Vec2{
+				X: 0,
+				Y: 0,
+			},
+		},
+		{
+			name: "star in the far top right quadrant",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: 490,
+						Y: 490,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				theta: 0.5,
+			},
+			want: structs.Vec2{
+				X: 0,
+				Y: 0,
+			},
+		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			if got := CalcAllForces(tt.args.database, tt.args.star, tt.args.theta); !reflect.DeepEqual(got, tt.want) {
+			if got := CalcAllForces(tt.args.database, tt.args.star, tt.args.galaxyIndex, tt.args.theta); !reflect.DeepEqual(got, tt.want) {
 				t.Errorf("CalcAllForces() = %v, want %v", got, tt.want)
 			}
 		})
 	}
 }
+
+func TestInsertStar(t *testing.T) {
+	// define a database
+	db = ConnectToDB()
+	db.SetMaxOpenConns(75)
+
+	type args struct {
+		database *sql.DB
+		star     structs.Star2D
+		index    int64
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		{
+			name: "Insert a star into the database",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: 100,
+						Y: 100,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				index: 1,
+			},
+		},
+		{
+			name: "Insert a star into the database",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: 150,
+						Y: 150,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				index: 1,
+			},
+		},
+		{
+			name: "Insert a star into the database",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: 150,
+						Y: 150,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				index: 2,
+			},
+		},
+		{
+			name: "Insert a star into the database",
+			args: args{
+				database: db,
+				star: structs.Star2D{
+					C: structs.Vec2{
+						X: 100,
+						Y: 100,
+					},
+					V: structs.Vec2{
+						X: 0,
+						Y: 0,
+					},
+					M: 1000,
+				},
+				index: 2,
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			InsertStar(tt.args.database, tt.args.star, tt.args.index)
+		})
+	}
+}