main
1package api
2
3import (
4 "context"
5 "mysh/pkg/mythic"
6 "testing"
7)
8
9// MockClient implements the MythicClient interface for testing
10type MockClient struct {
11 callbacks []mythic.Callback
12 err error
13}
14
15func (m *MockClient) GetActiveCallbacks(ctx context.Context) ([]mythic.Callback, error) {
16 if m.err != nil {
17 return nil, m.err
18 }
19 return m.callbacks, nil
20}
21
22func (m *MockClient) CreateTask(ctx context.Context, callbackID int, command, params string) (*mythic.Task, error) {
23 return nil, nil // Not implemented for this test
24}
25
26func (m *MockClient) GetTaskResponse(ctx context.Context, taskID int) (*mythic.Task, error) {
27 return nil, nil // Not implemented for this test
28}
29
30func (m *MockClient) GetTasksWithResponses(ctx context.Context, callbackID int, limit int) ([]mythic.Task, error) {
31 return nil, nil // Not implemented for this test
32}
33
34func (m *MockClient) GetAllTasksWithResponses(ctx context.Context, limit int) ([]mythic.Task, error) {
35 return nil, nil // Not implemented for this test
36}
37
38func TestFindActiveCallback(t *testing.T) {
39 tests := []struct {
40 name string
41 callbacks []mythic.Callback
42 callbackID int
43 expectErr bool
44 expectedID int
45 }{
46 {
47 name: "callback found",
48 callbacks: []mythic.Callback{
49 {ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
50 {ID: 2, DisplayID: 20, Host: "host2", User: "user2"},
51 },
52 callbackID: 20,
53 expectErr: false,
54 expectedID: 2,
55 },
56 {
57 name: "callback not found",
58 callbacks: []mythic.Callback{
59 {ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
60 },
61 callbackID: 99,
62 expectErr: true,
63 },
64 {
65 name: "no callbacks",
66 callbacks: []mythic.Callback{},
67 callbackID: 10,
68 expectErr: true,
69 },
70 }
71
72 for _, tt := range tests {
73 t.Run(tt.name, func(t *testing.T) {
74 client := &MockClient{callbacks: tt.callbacks}
75 ctx := context.Background()
76
77 result, err := FindActiveCallback(ctx, client, tt.callbackID)
78
79 if tt.expectErr {
80 if err == nil {
81 t.Error("Expected error, but got nil")
82 }
83 if result != nil {
84 t.Error("Expected nil result on error")
85 }
86 } else {
87 if err != nil {
88 t.Errorf("Unexpected error: %v", err)
89 }
90 if result == nil {
91 t.Error("Expected callback, but got nil")
92 } else if result.ID != tt.expectedID {
93 t.Errorf("Expected callback ID %d, got %d", tt.expectedID, result.ID)
94 }
95 }
96 })
97 }
98}
99
100func TestValidateCallbackExists(t *testing.T) {
101 callbacks := []mythic.Callback{
102 {ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
103 }
104 client := &MockClient{callbacks: callbacks}
105 ctx := context.Background()
106
107 // Test valid callback
108 result, err := ValidateCallbackExists(ctx, client, 10)
109 if err != nil {
110 t.Errorf("Unexpected error: %v", err)
111 }
112 if result == nil {
113 t.Error("Expected callback, but got nil")
114 }
115
116 // Test invalid callback
117 result, err = ValidateCallbackExists(ctx, client, 99)
118 if err == nil {
119 t.Error("Expected error for non-existent callback")
120 }
121 if result != nil {
122 t.Error("Expected nil result for non-existent callback")
123 }
124}