Commit d938d3a

Richard Luby <richluby@gmail.com>
2017-01-26 06:55:47
added db files
added ability to connect to database
1 parent 71eb4cb
Changed files (2)
db.go
@@ -0,0 +1,83 @@
+// db holds the functions relating to the database
+package main
+
+import (
+	"gopkg.in/mgo.v2"
+	"gopkg.in/mgo.v2/bson"
+	"strconv"
+)
+
+// DB_CONFIG defines the interactions within the database
+type DB_CONFIG struct {
+	// DB_name is the name of the database
+	DB_name,
+	// CollectionName is the name of the collection within the database
+	CollectionName string
+}
+
+// DB_config stores the configuration data for the database
+var DB_config = DB_CONFIG{
+	DB_name:        "records",
+	CollectionName: "cyber"}
+
+// DBsession contains the session information for the connection
+// to the database
+var DBsession *mgo.Session
+
+// initDBSession initializes the connection to the database containing
+// the list of records.
+// returns the *mgo.Session connected to the database
+func InitDBSession() (*mgo.Session, error) {
+	var err error
+	if serverConfig.DB_PORT != 0 {
+		DBsession, err = mgo.Dial(serverConfig.DB_ADDRESS + ":" + strconv.Itoa(serverConfig.DB_PORT))
+	}
+	if err != nil {
+		return nil, err
+	}
+	index := mgo.Index{Key: []string{"categorystring"},
+		DropDups: true,
+		Unique:   true}
+	DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName).EnsureIndex(index)
+	return DBsession, nil
+}
+
+// loadRecordsIntoDB takes the records in memory and loads them into the database
+func loadRecordsIntoDB() {
+	for _, category := range categories {
+		for _, record := range category.RecordArray {
+			addRecordToDB(&record)
+		}
+		for _, record := range category.UsedRecordArray {
+			addRecordToDB(&record)
+		}
+	}
+}
+
+// returnForQuery
+func returnForQuery(selectedCategories []string, numQuestions int) ([]Record, error) {
+	collection := DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName)
+	var results []Record
+	var collector []Record
+	var err error
+	for _, category := range selectedCategories {
+		err = collection.Find(bson.M{"categorystring": category}).All(&collector)
+		if err != nil {
+			return nil, err
+		} else {
+			results = append(results, collector...)
+		}
+	}
+	return results, nil
+}
+
+// addRecordToDB
+func addRecordToDB(record *Record) error {
+	if DBsession == nil {
+		if _, err := InitDBSession(); err != nil {
+			return err
+		}
+	}
+	err := DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName).Insert(record)
+	return err
+}
db_test.go
@@ -0,0 +1,53 @@
+// tests the database connections
+package main
+
+import (
+	"gopkg.in/mgo.v2"
+	"strings"
+	"testing"
+)
+
+// initDBForTest creates a session for the database testing
+func initDBForTest(t *testing.T) *mgo.Session {
+	serverConfig = SERVER_CONFIG{DB_ADDRESS: "127.0.0.1",
+		DB_PORT: 27017}
+	session, err := InitDBSession()
+	if err != nil {
+		t.Errorf("Error for initDB: %+v", err)
+		t.FailNow()
+	}
+	return session
+}
+
+// Test_AddAndRetriveForDB
+func Test_AddAndRetriveForDB(t *testing.T) {
+	t.SkipNow()
+	DBsession = initDBForTest(t)
+	var recordTest = Record{Question: "Question",
+		Reference: "Reference",
+		Answer:    "Answer",
+		Path:      ""}
+	if err := addRecordToDB(&recordTest); err != nil {
+		if !strings.Contains(err.Error(), "duplicate key error") {
+			t.Errorf("Error for adding record: %+v", err)
+		}
+	}
+	results, err := returnForQuery([]string{"CategoryString"}, 1)
+	if err != nil {
+		t.Errorf("Error for returning initial query: %+v", err)
+	}
+	if len(results) != 1 {
+		t.Errorf("Wrong length of results: %d: %+v", len(results), results)
+	}
+	if results[0] != recordTest {
+		t.Errorf("Unexpected result: %+v", results[0])
+	}
+	loadRecordsIntoDB()
+	results, err = returnForQuery([]string{"linux:networking:test"}, 2)
+	if err != nil {
+		t.Errorf("Error for returning test query: %+v", err)
+	}
+	if len(results) != 2 {
+		t.Errorf("Wrong length of test results: %d: %+v", len(results), results)
+	}
+}