main
Raw Download raw file
  1package api
  2
  3import (
  4	"context"
  5	"fmt"
  6	"mysh/pkg/mythic"
  7	"testing"
  8	"time"
  9)
 10
 11// MockTaskClient extends MockClient to support task operations
 12type MockTaskClient struct {
 13	MockClient
 14	tasks       map[int]*mythic.Task
 15	createErr   error
 16	getTaskErr  error
 17	taskCounter int
 18}
 19
 20func (m *MockTaskClient) CreateTask(ctx context.Context, callbackID int, command, params string) (*mythic.Task, error) {
 21	if m.createErr != nil {
 22		return nil, m.createErr
 23	}
 24
 25	m.taskCounter++
 26	task := &mythic.Task{
 27		ID:         m.taskCounter,
 28		DisplayID:  m.taskCounter * 10,
 29		Command:    command,
 30		Params:     params,
 31		Status:     TaskStatusSubmitted,
 32		CallbackID: callbackID,
 33	}
 34
 35	if m.tasks == nil {
 36		m.tasks = make(map[int]*mythic.Task)
 37	}
 38	m.tasks[task.ID] = task
 39
 40	return task, nil
 41}
 42
 43func (m *MockTaskClient) GetTaskResponse(ctx context.Context, taskID int) (*mythic.Task, error) {
 44	if m.getTaskErr != nil {
 45		return nil, m.getTaskErr
 46	}
 47
 48	task, exists := m.tasks[taskID]
 49	if !exists {
 50		return nil, fmt.Errorf("task %d not found", taskID)
 51	}
 52
 53	// Return a copy to simulate potential status changes
 54	return &mythic.Task{
 55		ID:         task.ID,
 56		DisplayID:  task.DisplayID,
 57		Command:    task.Command,
 58		Params:     task.Params,
 59		Status:     task.Status,
 60		Response:   task.Response,
 61		CallbackID: task.CallbackID,
 62		Completed:  task.Status == TaskStatusCompleted || task.Status == TaskStatusError,
 63	}, nil
 64}
 65
 66// SetTaskCompleted simulates a task completing
 67func (m *MockTaskClient) SetTaskCompleted(taskID int, response string) {
 68	if task, exists := m.tasks[taskID]; exists {
 69		task.Status = TaskStatusCompleted
 70		task.Response = response
 71		task.Completed = true
 72	}
 73}
 74
 75func (m *MockTaskClient) GetTasksWithResponses(ctx context.Context, callbackID int, limit int) ([]mythic.Task, error) {
 76	return nil, nil // Not implemented for this test
 77}
 78
 79func (m *MockTaskClient) GetAllTasksWithResponses(ctx context.Context, limit int) ([]mythic.Task, error) {
 80	return nil, nil // Not implemented for this test
 81}
 82
 83func TestDefaultTaskPollConfig(t *testing.T) {
 84	config := DefaultTaskPollConfig()
 85
 86	if config.TimeoutSeconds != int(DefaultTaskTimeout.Seconds()) {
 87		t.Errorf("Expected TimeoutSeconds to be %d, got %d", int(DefaultTaskTimeout.Seconds()), config.TimeoutSeconds)
 88	}
 89	if config.PollInterval != DefaultPollInterval {
 90		t.Errorf("Expected PollInterval to be %v, got %v", DefaultPollInterval, config.PollInterval)
 91	}
 92	if !config.ShowProgress {
 93		t.Error("Expected ShowProgress to be true")
 94	}
 95	if config.RawOutput {
 96		t.Error("Expected RawOutput to be false")
 97	}
 98}
 99
100func TestExecuteTaskAndWait_Success(t *testing.T) {
101	client := &MockTaskClient{}
102	ctx := context.Background()
103
104	config := TaskPollConfig{
105		TimeoutSeconds: 5,
106		PollInterval:   100 * time.Millisecond,
107		ShowProgress:   false,
108		RawOutput:      true,
109	}
110
111	// Start task execution in background
112	go func() {
113		time.Sleep(200 * time.Millisecond)
114		// Simulate task completion after a short delay
115		client.SetTaskCompleted(1, "Task completed successfully")
116	}()
117
118	result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
119
120	if err != nil {
121		t.Errorf("Unexpected error: %v", err)
122	}
123	if result == nil {
124		t.Fatal("Expected task result, got nil")
125	}
126	if result.Status != TaskStatusCompleted {
127		t.Errorf("Expected status %q, got %q", TaskStatusCompleted, result.Status)
128	}
129	if result.Response != "Task completed successfully" {
130		t.Errorf("Expected response %q, got %q", "Task completed successfully", result.Response)
131	}
132}
133
134func TestExecuteTaskAndWait_CreateTaskError(t *testing.T) {
135	client := &MockTaskClient{
136		createErr: fmt.Errorf("failed to create task"),
137	}
138	ctx := context.Background()
139
140	config := DefaultTaskPollConfig()
141
142	result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
143
144	if err == nil {
145		t.Error("Expected error, got nil")
146	}
147	if result != nil {
148		t.Error("Expected nil result on error")
149	}
150}
151
152func TestExecuteTaskAndWait_Timeout(t *testing.T) {
153	client := &MockTaskClient{}
154	ctx := context.Background()
155
156	config := TaskPollConfig{
157		TimeoutSeconds: 1, // Very short timeout
158		PollInterval:   100 * time.Millisecond,
159		ShowProgress:   false,
160		RawOutput:      true,
161	}
162
163	// Don't complete the task - let it timeout
164	result, err := ExecuteTaskAndWait(ctx, client, 1, "ls", "-la", config)
165
166	if err == nil {
167		t.Error("Expected timeout error, got nil")
168	}
169	if result != nil {
170		t.Error("Expected nil result on timeout")
171	}
172}