Commit d938d3a
Changed files (2)
db.go
@@ -0,0 +1,83 @@
+// db holds the functions relating to the database
+package main
+
+import (
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+ "strconv"
+)
+
+// DB_CONFIG defines the interactions within the database
+type DB_CONFIG struct {
+ // DB_name is the name of the database
+ DB_name,
+ // CollectionName is the name of the collection within the database
+ CollectionName string
+}
+
+// DB_config stores the configuration data for the database
+var DB_config = DB_CONFIG{
+ DB_name: "records",
+ CollectionName: "cyber"}
+
+// DBsession contains the session information for the connection
+// to the database
+var DBsession *mgo.Session
+
+// initDBSession initializes the connection to the database containing
+// the list of records.
+// returns the *mgo.Session connected to the database
+func InitDBSession() (*mgo.Session, error) {
+ var err error
+ if serverConfig.DB_PORT != 0 {
+ DBsession, err = mgo.Dial(serverConfig.DB_ADDRESS + ":" + strconv.Itoa(serverConfig.DB_PORT))
+ }
+ if err != nil {
+ return nil, err
+ }
+ index := mgo.Index{Key: []string{"categorystring"},
+ DropDups: true,
+ Unique: true}
+ DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName).EnsureIndex(index)
+ return DBsession, nil
+}
+
+// loadRecordsIntoDB takes the records in memory and loads them into the database
+func loadRecordsIntoDB() {
+ for _, category := range categories {
+ for _, record := range category.RecordArray {
+ addRecordToDB(&record)
+ }
+ for _, record := range category.UsedRecordArray {
+ addRecordToDB(&record)
+ }
+ }
+}
+
+// returnForQuery
+func returnForQuery(selectedCategories []string, numQuestions int) ([]Record, error) {
+ collection := DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName)
+ var results []Record
+ var collector []Record
+ var err error
+ for _, category := range selectedCategories {
+ err = collection.Find(bson.M{"categorystring": category}).All(&collector)
+ if err != nil {
+ return nil, err
+ } else {
+ results = append(results, collector...)
+ }
+ }
+ return results, nil
+}
+
+// addRecordToDB
+func addRecordToDB(record *Record) error {
+ if DBsession == nil {
+ if _, err := InitDBSession(); err != nil {
+ return err
+ }
+ }
+ err := DBsession.DB(DB_config.DB_name).C(DB_config.CollectionName).Insert(record)
+ return err
+}
db_test.go
@@ -0,0 +1,53 @@
+// tests the database connections
+package main
+
+import (
+ "gopkg.in/mgo.v2"
+ "strings"
+ "testing"
+)
+
+// initDBForTest creates a session for the database testing
+func initDBForTest(t *testing.T) *mgo.Session {
+ serverConfig = SERVER_CONFIG{DB_ADDRESS: "127.0.0.1",
+ DB_PORT: 27017}
+ session, err := InitDBSession()
+ if err != nil {
+ t.Errorf("Error for initDB: %+v", err)
+ t.FailNow()
+ }
+ return session
+}
+
+// Test_AddAndRetriveForDB
+func Test_AddAndRetriveForDB(t *testing.T) {
+ t.SkipNow()
+ DBsession = initDBForTest(t)
+ var recordTest = Record{Question: "Question",
+ Reference: "Reference",
+ Answer: "Answer",
+ Path: ""}
+ if err := addRecordToDB(&recordTest); err != nil {
+ if !strings.Contains(err.Error(), "duplicate key error") {
+ t.Errorf("Error for adding record: %+v", err)
+ }
+ }
+ results, err := returnForQuery([]string{"CategoryString"}, 1)
+ if err != nil {
+ t.Errorf("Error for returning initial query: %+v", err)
+ }
+ if len(results) != 1 {
+ t.Errorf("Wrong length of results: %d: %+v", len(results), results)
+ }
+ if results[0] != recordTest {
+ t.Errorf("Unexpected result: %+v", results[0])
+ }
+ loadRecordsIntoDB()
+ results, err = returnForQuery([]string{"linux:networking:test"}, 2)
+ if err != nil {
+ t.Errorf("Error for returning test query: %+v", err)
+ }
+ if len(results) != 2 {
+ t.Errorf("Wrong length of test results: %d: %+v", len(results), results)
+ }
+}