main
1package cache
2
3import (
4 "crypto/sha256"
5 "encoding/json"
6 "fmt"
7 "mysh/pkg/mythic"
8 "net/url"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13)
14
// TaskCache is an on-disk store of finished Mythic task results. Each cached
// task is kept as a JSON file inside cacheDir, the per-server directory that
// New selects and creates.
type TaskCache struct {
	cacheDir string // directory holding this server's cache entry files
}
19
20// extractHostname safely extracts hostname from a server URL for directory naming
21func extractHostname(serverURL string) string {
22 // Parse the URL to extract hostname
23 parsedURL, err := url.Parse(serverURL)
24 if err != nil {
25 // If parsing fails, use a sanitized version of the full URL
26 return sanitizeForPath(serverURL)
27 }
28
29 hostname := parsedURL.Hostname()
30 if hostname == "" {
31 // Fallback to host (includes port if present)
32 hostname = parsedURL.Host
33 }
34
35 if hostname == "" {
36 // Final fallback to sanitized URL
37 return sanitizeForPath(serverURL)
38 }
39
40 return sanitizeForPath(hostname)
41}
42
// pathUnsafeReplacer maps each character in the unsafe set to an underscore
// in a single pass. The set is path separators, ':', the wildcard, quote and
// redirection characters, '|', and space.
var pathUnsafeReplacer = strings.NewReplacer(
	"/", "_",
	"\\", "_",
	":", "_",
	"*", "_",
	"?", "_",
	"\"", "_",
	"<", "_",
	">", "_",
	"|", "_",
	" ", "_",
)

// sanitizeForPath turns input into a single, safe directory-name component.
//
// Unsafe characters are replaced with underscores. The results "", "." and
// ".." are also mapped to underscores. Left as-is, filepath.Join would treat
// them as "this directory" or "parent directory". For example, a server URL of
// "http://.." would otherwise make the cache directory escape its "mysh"
// parent, and ClearCache would then operate on an unrelated directory.
func sanitizeForPath(input string) string {
	result := pathUnsafeReplacer.Replace(input)
	switch result {
	case "":
		return "_"
	case ".", "..":
		return strings.Repeat("_", len(result))
	}
	return result
}
53
54// New creates a new TaskCache instance using XDG cache directory with server-specific subdirectory
55func New(serverURL string) (*TaskCache, error) {
56 // Use XDG_CACHE_HOME if set, otherwise use default ~/.cache
57 cacheHome := os.Getenv("XDG_CACHE_HOME")
58 if cacheHome == "" {
59 homeDir, err := os.UserHomeDir()
60 if err != nil {
61 return nil, fmt.Errorf("failed to get user home directory: %w", err)
62 }
63 cacheHome = filepath.Join(homeDir, ".cache")
64 }
65
66 // Extract hostname for server-specific directory
67 hostname := extractHostname(serverURL)
68 cacheDir := filepath.Join(cacheHome, "mysh", hostname)
69
70 // Create cache directory if it doesn't exist
71 if err := os.MkdirAll(cacheDir, 0755); err != nil {
72 return nil, fmt.Errorf("failed to create cache directory: %w", err)
73 }
74
75 return &TaskCache{cacheDir: cacheDir}, nil
76}
77
78// CachedTask represents a cached task with metadata
79type CachedTask struct {
80 Task *mythic.Task `json:"task"`
81 CachedAt time.Time `json:"cached_at"`
82 ServerURL string `json:"server_url"`
83}
84
85// generateCacheKey creates a unique cache key for a task
86func (tc *TaskCache) generateCacheKey(taskID int, serverURL string) string {
87 // Create a hash based on task ID and server URL to ensure uniqueness
88 h := sha256.New()
89 h.Write([]byte(fmt.Sprintf("%d:%s", taskID, serverURL)))
90 return fmt.Sprintf("task_%d_%x.json", taskID, h.Sum(nil)[:8])
91}
92
93// GetCachedTask retrieves a cached task if it exists and is for a completed task
94func (tc *TaskCache) GetCachedTask(taskID int, serverURL string) (*mythic.Task, bool) {
95 cacheKey := tc.generateCacheKey(taskID, serverURL)
96 cachePath := filepath.Join(tc.cacheDir, cacheKey)
97
98 // Check if cache file exists
99 if _, err := os.Stat(cachePath); os.IsNotExist(err) {
100 return nil, false
101 }
102
103 // Read cache file
104 file, err := os.Open(cachePath)
105 if err != nil {
106 return nil, false
107 }
108 defer file.Close()
109
110 // Decode cached task
111 var cachedTask CachedTask
112 if err := json.NewDecoder(file).Decode(&cachedTask); err != nil {
113 // If cache is corrupted, remove it
114 os.Remove(cachePath)
115 return nil, false
116 }
117
118 // Verify this cache is for the same server
119 if cachedTask.ServerURL != serverURL {
120 return nil, false
121 }
122
123 // Only return cached results for completed tasks
124 if !cachedTask.Task.Completed && cachedTask.Task.Status != "completed" && cachedTask.Task.Status != "error" {
125 return nil, false
126 }
127
128 return cachedTask.Task, true
129}
130
131// CacheTask stores a completed task result in cache
132func (tc *TaskCache) CacheTask(task *mythic.Task, serverURL string) error {
133 // Only cache completed tasks
134 if !task.Completed && task.Status != "completed" && task.Status != "error" {
135 return nil
136 }
137
138 cacheKey := tc.generateCacheKey(task.ID, serverURL)
139 cachePath := filepath.Join(tc.cacheDir, cacheKey)
140
141 // Create cache entry
142 cachedTask := CachedTask{
143 Task: task,
144 CachedAt: time.Now(),
145 ServerURL: serverURL,
146 }
147
148 // Write to temporary file first, then rename (atomic operation)
149 tempPath := cachePath + ".tmp"
150 file, err := os.Create(tempPath)
151 if err != nil {
152 return fmt.Errorf("failed to create cache file: %w", err)
153 }
154 defer file.Close()
155
156 if err := json.NewEncoder(file).Encode(cachedTask); err != nil {
157 os.Remove(tempPath)
158 return fmt.Errorf("failed to encode cache data: %w", err)
159 }
160
161 // Atomic rename
162 if err := os.Rename(tempPath, cachePath); err != nil {
163 os.Remove(tempPath)
164 return fmt.Errorf("failed to finalize cache file: %w", err)
165 }
166
167 return nil
168}
169
170// CleanOldCache removes cache entries older than the specified duration
171func (tc *TaskCache) CleanOldCache(maxAge time.Duration) error {
172 entries, err := os.ReadDir(tc.cacheDir)
173 if err != nil {
174 return fmt.Errorf("failed to read cache directory: %w", err)
175 }
176
177 cutoff := time.Now().Add(-maxAge)
178
179 for _, entry := range entries {
180 if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
181 cachePath := filepath.Join(tc.cacheDir, entry.Name())
182
183 // Check file modification time
184 info, err := entry.Info()
185 if err != nil {
186 continue
187 }
188
189 if info.ModTime().Before(cutoff) {
190 os.Remove(cachePath)
191 }
192 }
193 }
194
195 return nil
196}
197
198// GetCacheInfo returns information about the cache directory
199func (tc *TaskCache) GetCacheInfo() (string, int, int64, error) {
200 entries, err := os.ReadDir(tc.cacheDir)
201 if err != nil {
202 return tc.cacheDir, 0, 0, fmt.Errorf("failed to read cache directory: %w", err)
203 }
204
205 var totalSize int64
206 fileCount := 0
207
208 for _, entry := range entries {
209 if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
210 info, err := entry.Info()
211 if err == nil {
212 totalSize += info.Size()
213 fileCount++
214 }
215 }
216 }
217
218 return tc.cacheDir, fileCount, totalSize, nil
219}
220
221// ClearCache removes all cached task results
222func (tc *TaskCache) ClearCache() error {
223 entries, err := os.ReadDir(tc.cacheDir)
224 if err != nil {
225 return fmt.Errorf("failed to read cache directory: %w", err)
226 }
227
228 for _, entry := range entries {
229 if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
230 cachePath := filepath.Join(tc.cacheDir, entry.Name())
231 if err := os.Remove(cachePath); err != nil {
232 return fmt.Errorf("failed to remove cache file %s: %w", entry.Name(), err)
233 }
234 }
235 }
236
237 return nil
238}