main
1package api
2
3import (
4 "context"
5 "fmt"
6 "mysh/pkg/mythic"
7 "testing"
8 "time"
9)
10
11// MockTaskClient extends MockClient to support task operations
12type MockTaskClient struct {
13 MockClient
14 tasks map[int]*mythic.Task
15 createErr error
16 getTaskErr error
17 taskCounter int
18}
19
20func (m *MockTaskClient) CreateTask(ctx context.Context, callbackID int, command, params string) (*mythic.Task, error) {
21 if m.createErr != nil {
22 return nil, m.createErr
23 }
24
25 m.taskCounter++
26 task := &mythic.Task{
27 ID: m.taskCounter,
28 DisplayID: m.taskCounter * 10,
29 Command: command,
30 Params: params,
31 Status: TaskStatusSubmitted,
32 CallbackID: callbackID,
33 }
34
35 if m.tasks == nil {
36 m.tasks = make(map[int]*mythic.Task)
37 }
38 m.tasks[task.ID] = task
39
40 return task, nil
41}
42
43func (m *MockTaskClient) GetTaskResponse(ctx context.Context, taskID int) (*mythic.Task, error) {
44 if m.getTaskErr != nil {
45 return nil, m.getTaskErr
46 }
47
48 task, exists := m.tasks[taskID]
49 if !exists {
50 return nil, fmt.Errorf("task %d not found", taskID)
51 }
52
53 // Return a copy to simulate potential status changes
54 return &mythic.Task{
55 ID: task.ID,
56 DisplayID: task.DisplayID,
57 Command: task.Command,
58 Params: task.Params,
59 Status: task.Status,
60 Response: task.Response,
61 CallbackID: task.CallbackID,
62 Completed: task.Status == TaskStatusCompleted || task.Status == TaskStatusError,
63 }, nil
64}
65
66// SetTaskCompleted simulates a task completing
67func (m *MockTaskClient) SetTaskCompleted(taskID int, response string) {
68 if task, exists := m.tasks[taskID]; exists {
69 task.Status = TaskStatusCompleted
70 task.Response = response
71 task.Completed = true
72 }
73}
74
75func (m *MockTaskClient) GetTasksWithResponses(ctx context.Context, callbackID int, limit int) ([]mythic.Task, error) {
76 return nil, nil // Not implemented for this test
77}
78
79func (m *MockTaskClient) GetAllTasksWithResponses(ctx context.Context, limit int) ([]mythic.Task, error) {
80 return nil, nil // Not implemented for this test
81}
82
83func TestDefaultTaskPollConfig(t *testing.T) {
84 config := DefaultTaskPollConfig()
85
86 if config.TimeoutSeconds != int(DefaultTaskTimeout.Seconds()) {
87 t.Errorf("Expected TimeoutSeconds to be %d, got %d", int(DefaultTaskTimeout.Seconds()), config.TimeoutSeconds)
88 }
89 if config.PollInterval != DefaultPollInterval {
90 t.Errorf("Expected PollInterval to be %v, got %v", DefaultPollInterval, config.PollInterval)
91 }
92 if !config.ShowProgress {
93 t.Error("Expected ShowProgress to be true")
94 }
95 if config.RawOutput {
96 t.Error("Expected RawOutput to be false")
97 }
98}
99
100func TestExecuteTaskAndWait_Success(t *testing.T) {
101 client := &MockTaskClient{}
102 ctx := context.Background()
103
104 config := TaskPollConfig{
105 TimeoutSeconds: 5,
106 PollInterval: 100 * time.Millisecond,
107 ShowProgress: false,
108 RawOutput: true,
109 }
110
111 // Start task execution in background
112 go func() {
113 time.Sleep(200 * time.Millisecond)
114 // Simulate task completion after a short delay
115 client.SetTaskCompleted(1, "Task completed successfully")
116 }()
117
118 result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
119
120 if err != nil {
121 t.Errorf("Unexpected error: %v", err)
122 }
123 if result == nil {
124 t.Fatal("Expected task result, got nil")
125 }
126 if result.Status != TaskStatusCompleted {
127 t.Errorf("Expected status %q, got %q", TaskStatusCompleted, result.Status)
128 }
129 if result.Response != "Task completed successfully" {
130 t.Errorf("Expected response %q, got %q", "Task completed successfully", result.Response)
131 }
132}
133
134func TestExecuteTaskAndWait_CreateTaskError(t *testing.T) {
135 client := &MockTaskClient{
136 createErr: fmt.Errorf("failed to create task"),
137 }
138 ctx := context.Background()
139
140 config := DefaultTaskPollConfig()
141
142 result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
143
144 if err == nil {
145 t.Error("Expected error, got nil")
146 }
147 if result != nil {
148 t.Error("Expected nil result on error")
149 }
150}
151
152func TestExecuteTaskAndWait_Timeout(t *testing.T) {
153 client := &MockTaskClient{}
154 ctx := context.Background()
155
156 config := TaskPollConfig{
157 TimeoutSeconds: 1, // Very short timeout
158 PollInterval: 100 * time.Millisecond,
159 ShowProgress: false,
160 RawOutput: true,
161 }
162
163 // Don't complete the task - let it timeout
164 result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
165
166 if err == nil {
167 t.Error("Expected timeout error, got nil")
168 }
169 if result != nil {
170 t.Error("Expected nil result on timeout")
171 }
172}