main
Raw Download raw file
  1package api
  2
  3import (
  4	"context"
  5	"mysh/pkg/mythic"
  6	"testing"
  7)
  8
  9// MockClient implements the MythicClient interface for testing
 10type MockClient struct {
 11	callbacks []mythic.Callback
 12	err       error
 13}
 14
 15func (m *MockClient) GetActiveCallbacks(ctx context.Context) ([]mythic.Callback, error) {
 16	if m.err != nil {
 17		return nil, m.err
 18	}
 19	return m.callbacks, nil
 20}
 21
 22func (m *MockClient) CreateTask(ctx context.Context, callbackID int, command, params string) (*mythic.Task, error) {
 23	return nil, nil // Not implemented for this test
 24}
 25
 26func (m *MockClient) GetTaskResponse(ctx context.Context, taskID int) (*mythic.Task, error) {
 27	return nil, nil // Not implemented for this test
 28}
 29
 30func (m *MockClient) GetTasksWithResponses(ctx context.Context, callbackID int, limit int) ([]mythic.Task, error) {
 31	return nil, nil // Not implemented for this test
 32}
 33
 34func (m *MockClient) GetAllTasksWithResponses(ctx context.Context, limit int) ([]mythic.Task, error) {
 35	return nil, nil // Not implemented for this test
 36}
 37
 38func TestFindActiveCallback(t *testing.T) {
 39	tests := []struct {
 40		name       string
 41		callbacks  []mythic.Callback
 42		callbackID int
 43		expectErr  bool
 44		expectedID int
 45	}{
 46		{
 47			name: "callback found",
 48			callbacks: []mythic.Callback{
 49				{ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
 50				{ID: 2, DisplayID: 20, Host: "host2", User: "user2"},
 51			},
 52			callbackID: 20,
 53			expectErr:  false,
 54			expectedID: 2,
 55		},
 56		{
 57			name: "callback not found",
 58			callbacks: []mythic.Callback{
 59				{ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
 60			},
 61			callbackID: 99,
 62			expectErr:  true,
 63		},
 64		{
 65			name:       "no callbacks",
 66			callbacks:  []mythic.Callback{},
 67			callbackID: 10,
 68			expectErr:  true,
 69		},
 70	}
 71
 72	for _, tt := range tests {
 73		t.Run(tt.name, func(t *testing.T) {
 74			client := &MockClient{callbacks: tt.callbacks}
 75			ctx := context.Background()
 76
 77			result, err := FindActiveCallback(ctx, client, tt.callbackID)
 78
 79			if tt.expectErr {
 80				if err == nil {
 81					t.Error("Expected error, but got nil")
 82				}
 83				if result != nil {
 84					t.Error("Expected nil result on error")
 85				}
 86			} else {
 87				if err != nil {
 88					t.Errorf("Unexpected error: %v", err)
 89				}
 90				if result == nil {
 91					t.Error("Expected callback, but got nil")
 92				} else if result.ID != tt.expectedID {
 93					t.Errorf("Expected callback ID %d, got %d", tt.expectedID, result.ID)
 94				}
 95			}
 96		})
 97	}
 98}
 99
100func TestValidateCallbackExists(t *testing.T) {
101	callbacks := []mythic.Callback{
102		{ID: 1, DisplayID: 10, Host: "host1", User: "user1"},
103	}
104	client := &MockClient{callbacks: callbacks}
105	ctx := context.Background()
106
107	// Test valid callback
108	result, err := ValidateCallbackExists(ctx, client, 10)
109	if err != nil {
110		t.Errorf("Unexpected error: %v", err)
111	}
112	if result == nil {
113		t.Error("Expected callback, but got nil")
114	}
115
116	// Test invalid callback
117	result, err = ValidateCallbackExists(ctx, client, 99)
118	if err == nil {
119		t.Error("Expected error for non-existent callback")
120	}
121	if result != nil {
122		t.Error("Expected nil result for non-existent callback")
123	}
124}