main
Raw Download raw file
  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"mysh/pkg/cache"
  7	"mysh/pkg/mythic"
  8	"mysh/pkg/mythic/api"
  9	"sort"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/spf13/cobra"
 15)
 16
 17var (
 18	taskWaitTime  int
 19	taskRawOutput bool
 20	taskNoWait    bool
 21	taskCopyID    int
 22)
 23
 24var taskCmd = &cobra.Command{
 25	Use:     "task <callback_id(s)> [command] [args...]",
 26	Aliases: []string{"t", "exec"},
 27	Short:   "Send a command to one or more agents",
 28	Long:    "Send a command to specified agent callback(s) and wait for response. Supports single IDs, comma-separated lists, and ranges (e.g., '1,3-5,10'). Use --copy-task to copy parameters from an existing task.",
 29	Args:    cobra.MinimumNArgs(1),
 30	RunE:    runTask,
 31}
 32
 33func init() {
 34	rootCmd.AddCommand(taskCmd)
 35	taskCmd.Flags().IntVarP(&taskWaitTime, "wait", "w", 30, "Maximum time to wait for response (seconds)")
 36	taskCmd.Flags().BoolVar(&taskRawOutput, "raw", false, "Output only raw response bytes")
 37	taskCmd.Flags().BoolVar(&taskNoWait, "no-wait", false, "Create task but don't wait for response")
 38	taskCmd.Flags().IntVar(&taskCopyID, "copy-task", 0, "Copy command and parameters from existing task ID")
 39}
 40
 41// parseCallbackIDs parses a callback ID string that can contain:
 42// - Single ID: "1"
 43// - Comma-separated list: "1,2,3"
 44// - Ranges: "1-5" (expands to 1,2,3,4,5)
 45// - Mixed: "1,3-5,10" (expands to 1,3,4,5,10)
 46func parseCallbackIDs(input string) ([]int, error) {
 47	var callbackIDs []int
 48	seen := make(map[int]bool)
 49
 50	parts := strings.Split(input, ",")
 51	for _, part := range parts {
 52		part = strings.TrimSpace(part)
 53		if part == "" {
 54			continue
 55		}
 56
 57		if strings.Contains(part, "-") {
 58			// Handle range (e.g., "3-5")
 59			rangeParts := strings.Split(part, "-")
 60			if len(rangeParts) != 2 {
 61				return nil, fmt.Errorf("invalid range format: %s", part)
 62			}
 63
 64			start, err := strconv.Atoi(strings.TrimSpace(rangeParts[0]))
 65			if err != nil {
 66				return nil, fmt.Errorf("invalid start of range: %s", rangeParts[0])
 67			}
 68
 69			end, err := strconv.Atoi(strings.TrimSpace(rangeParts[1]))
 70			if err != nil {
 71				return nil, fmt.Errorf("invalid end of range: %s", rangeParts[1])
 72			}
 73
 74			if start > end {
 75				return nil, fmt.Errorf("invalid range: start (%d) is greater than end (%d)", start, end)
 76			}
 77
 78			// Add all IDs in range
 79			for i := start; i <= end; i++ {
 80				if !seen[i] {
 81					callbackIDs = append(callbackIDs, i)
 82					seen[i] = true
 83				}
 84			}
 85		} else {
 86			// Handle single ID
 87			id, err := strconv.Atoi(part)
 88			if err != nil {
 89				return nil, fmt.Errorf("invalid callback ID: %s", part)
 90			}
 91
 92			if !seen[id] {
 93				callbackIDs = append(callbackIDs, id)
 94				seen[id] = true
 95			}
 96		}
 97	}
 98
 99	if len(callbackIDs) == 0 {
100		return nil, fmt.Errorf("no valid callback IDs found")
101	}
102
103	// Sort the IDs for consistent output
104	sort.Ints(callbackIDs)
105	return callbackIDs, nil
106}
107
108func runTask(cmd *cobra.Command, args []string) error {
109	if err := validateConfig(); err != nil {
110		return err
111	}
112
113	// Parse callback IDs (supports lists and ranges)
114	callbackIDs, err := parseCallbackIDs(args[0])
115	if err != nil {
116		return fmt.Errorf("invalid callback ID(s): %w", err)
117	}
118
119	var command, params string
120
121	client := mythic.NewClient(mythicURL, token, insecure, socksProxy)
122	ctx := context.Background()
123
124	// Handle task copying
125	if taskCopyID > 0 {
126		// Copy task parameters from existing task
127		copyCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
128		defer cancel()
129
130		// Initialize cache for task lookup
131		taskCache, err := cache.New(mythicURL)
132		if err != nil {
133			taskCache = nil
134		}
135
136		// Try to get task from cache first
137		var sourceTask *mythic.Task
138		if taskCache != nil {
139			if cachedTask, found := taskCache.GetCachedTask(taskCopyID, mythicURL); found {
140				sourceTask = cachedTask
141			}
142		}
143
144		// If not in cache, fetch from server
145		if sourceTask == nil {
146			sourceTask, err = client.GetTaskResponse(copyCtx, taskCopyID)
147			if err != nil {
148				return fmt.Errorf("failed to get source task %d: %w", taskCopyID, err)
149			}
150
151			// Cache the result if task is completed and cache is available
152			if taskCache != nil {
153				taskCache.CacheTask(sourceTask, mythicURL)
154			}
155		}
156
157		// Use the source task's command and original parameters
158		command = sourceTask.Command
159		if sourceTask.OriginalParams != "" {
160			params = sourceTask.OriginalParams
161		} else if sourceTask.Params != "" {
162			params = sourceTask.Params
163		}
164
165		fmt.Printf("Copying task %d: %s\n", taskCopyID, command)
166		if params != "" {
167			fmt.Printf("Parameters: %s\n", params)
168		}
169		fmt.Println()
170	} else {
171		// Normal mode: use command line arguments
172		if len(args) < 2 {
173			return fmt.Errorf("command is required when not using --copy-task")
174		}
175		command = args[1]
176		if len(args) > 2 {
177			params = strings.Join(args[2:], " ")
178		}
179	}
180
181	// Track results for multiple callbacks
182	type TaskResult struct {
183		CallbackID int
184		Callback   *mythic.Callback
185		Task       *mythic.Task
186		Error      error
187	}
188
189	var results []TaskResult
190
191	// For multiple callbacks, automatically use no-wait mode
192	useNoWait := taskNoWait || len(callbackIDs) > 1
193
194	// Process each callback
195	for _, callbackID := range callbackIDs {
196		result := TaskResult{CallbackID: callbackID}
197
198		// Verify callback exists and is active
199		targetCallback, err := api.FindActiveCallback(ctx, client, callbackID)
200		if err != nil {
201			result.Error = err
202			results = append(results, result)
203			continue
204		}
205		result.Callback = targetCallback
206
207		if useNoWait {
208			// Create task but don't wait for response
209			task, err := client.CreateTask(ctx, targetCallback.ID, command, params)
210			if err != nil {
211				result.Error = fmt.Errorf("failed to create task: %w", err)
212			} else {
213				result.Task = task
214			}
215			results = append(results, result)
216		} else {
217			// Configure polling (only for single callback)
218			config := api.TaskPollConfig{
219				TimeoutSeconds: taskWaitTime,
220				PollInterval:   api.DefaultPollInterval,
221				ShowProgress:   !taskRawOutput,
222				RawOutput:      taskRawOutput,
223			}
224
225			// Execute task and wait for response
226			task, err := api.ExecuteTaskAndWait(ctx, client, targetCallback.ID, command, params, config)
227			if err != nil {
228				result.Error = err
229			} else {
230				result.Task = task
231			}
232			results = append(results, result)
233		}
234	}
235
236	// Display results
237	if taskRawOutput {
238		// Raw output: only print responses
239		for _, result := range results {
240			if result.Task != nil && result.Task.Response != "" {
241				fmt.Print(result.Task.Response)
242			}
243		}
244	} else {
245		// Normal output: show detailed results
246		if len(callbackIDs) == 1 {
247			// Single callback: use original format
248			result := results[0]
249			if result.Error != nil {
250				return result.Error
251			}
252
253			if useNoWait {
254				fmt.Printf("Task %d created on callback %d (%s@%s) - not waiting for response\n",
255					result.Task.DisplayID, result.CallbackID, result.Callback.User, result.Callback.Host)
256			} else {
257				fmt.Printf("Task Status: %s\n", result.Task.Status)
258				if result.Task.Response != "" {
259					fmt.Printf("Response:\n%s\n", result.Task.Response)
260				}
261			}
262		} else {
263			// Multiple callbacks: show results for each
264			fmt.Printf("Executing '%s %s' on %d callbacks:\n\n", command, params, len(callbackIDs))
265
266			for _, result := range results {
267				fmt.Printf("=== Callback %d", result.CallbackID)
268				if result.Callback != nil {
269					fmt.Printf(" (%s@%s)", result.Callback.User, result.Callback.Host)
270				}
271				fmt.Printf(" ===\n")
272
273				if result.Error != nil {
274					fmt.Printf("Error: %v\n", result.Error)
275				} else if useNoWait {
276					fmt.Printf("Task %d created - not waiting for response\n", result.Task.DisplayID)
277				} else {
278					fmt.Printf("Status: %s\n", result.Task.Status)
279					if result.Task.Response != "" {
280						fmt.Printf("Response:\n%s\n", result.Task.Response)
281					}
282				}
283				fmt.Println()
284			}
285		}
286	}
287
288	return nil
289}