main
Raw Download raw file
  1package cache
  2
  3import (
  4	"crypto/sha256"
  5	"encoding/json"
  6	"fmt"
  7	"mysh/pkg/mythic"
  8	"net/url"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13)
 14
 15// TaskCache manages caching of completed task results
 16type TaskCache struct {
 17	cacheDir string
 18}
 19
 20// extractHostname safely extracts hostname from a server URL for directory naming
 21func extractHostname(serverURL string) string {
 22	// Parse the URL to extract hostname
 23	parsedURL, err := url.Parse(serverURL)
 24	if err != nil {
 25		// If parsing fails, use a sanitized version of the full URL
 26		return sanitizeForPath(serverURL)
 27	}
 28
 29	hostname := parsedURL.Hostname()
 30	if hostname == "" {
 31		// Fallback to host (includes port if present)
 32		hostname = parsedURL.Host
 33	}
 34
 35	if hostname == "" {
 36		// Final fallback to sanitized URL
 37		return sanitizeForPath(serverURL)
 38	}
 39
 40	return sanitizeForPath(hostname)
 41}
 42
 43// sanitizeForPath removes characters that aren't safe for directory names
 44func sanitizeForPath(input string) string {
 45	// Replace unsafe characters with underscores
 46	unsafe := []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " "}
 47	result := input
 48	for _, char := range unsafe {
 49		result = strings.ReplaceAll(result, char, "_")
 50	}
 51	return result
 52}
 53
 54// New creates a new TaskCache instance using XDG cache directory with server-specific subdirectory
 55func New(serverURL string) (*TaskCache, error) {
 56	// Use XDG_CACHE_HOME if set, otherwise use default ~/.cache
 57	cacheHome := os.Getenv("XDG_CACHE_HOME")
 58	if cacheHome == "" {
 59		homeDir, err := os.UserHomeDir()
 60		if err != nil {
 61			return nil, fmt.Errorf("failed to get user home directory: %w", err)
 62		}
 63		cacheHome = filepath.Join(homeDir, ".cache")
 64	}
 65
 66	// Extract hostname for server-specific directory
 67	hostname := extractHostname(serverURL)
 68	cacheDir := filepath.Join(cacheHome, "mysh", hostname)
 69
 70	// Create cache directory if it doesn't exist
 71	if err := os.MkdirAll(cacheDir, 0755); err != nil {
 72		return nil, fmt.Errorf("failed to create cache directory: %w", err)
 73	}
 74
 75	return &TaskCache{cacheDir: cacheDir}, nil
 76}
 77
 78// CachedTask represents a cached task with metadata
 79type CachedTask struct {
 80	Task      *mythic.Task `json:"task"`
 81	CachedAt  time.Time    `json:"cached_at"`
 82	ServerURL string       `json:"server_url"`
 83}
 84
 85// generateCacheKey creates a unique cache key for a task
 86func (tc *TaskCache) generateCacheKey(taskID int, serverURL string) string {
 87	// Create a hash based on task ID and server URL to ensure uniqueness
 88	h := sha256.New()
 89	h.Write([]byte(fmt.Sprintf("%d:%s", taskID, serverURL)))
 90	return fmt.Sprintf("task_%d_%x.json", taskID, h.Sum(nil)[:8])
 91}
 92
 93// GetCachedTask retrieves a cached task if it exists and is for a completed task
 94func (tc *TaskCache) GetCachedTask(taskID int, serverURL string) (*mythic.Task, bool) {
 95	cacheKey := tc.generateCacheKey(taskID, serverURL)
 96	cachePath := filepath.Join(tc.cacheDir, cacheKey)
 97
 98	// Check if cache file exists
 99	if _, err := os.Stat(cachePath); os.IsNotExist(err) {
100		return nil, false
101	}
102
103	// Read cache file
104	file, err := os.Open(cachePath)
105	if err != nil {
106		return nil, false
107	}
108	defer file.Close()
109
110	// Decode cached task
111	var cachedTask CachedTask
112	if err := json.NewDecoder(file).Decode(&cachedTask); err != nil {
113		// If cache is corrupted, remove it
114		os.Remove(cachePath)
115		return nil, false
116	}
117
118	// Verify this cache is for the same server
119	if cachedTask.ServerURL != serverURL {
120		return nil, false
121	}
122
123	// Only return cached results for completed tasks
124	if !cachedTask.Task.Completed && cachedTask.Task.Status != "completed" && cachedTask.Task.Status != "error" {
125		return nil, false
126	}
127
128	return cachedTask.Task, true
129}
130
131// CacheTask stores a completed task result in cache
132func (tc *TaskCache) CacheTask(task *mythic.Task, serverURL string) error {
133	// Only cache completed tasks
134	if !task.Completed && task.Status != "completed" && task.Status != "error" {
135		return nil
136	}
137
138	cacheKey := tc.generateCacheKey(task.ID, serverURL)
139	cachePath := filepath.Join(tc.cacheDir, cacheKey)
140
141	// Create cache entry
142	cachedTask := CachedTask{
143		Task:      task,
144		CachedAt:  time.Now(),
145		ServerURL: serverURL,
146	}
147
148	// Write to temporary file first, then rename (atomic operation)
149	tempPath := cachePath + ".tmp"
150	file, err := os.Create(tempPath)
151	if err != nil {
152		return fmt.Errorf("failed to create cache file: %w", err)
153	}
154	defer file.Close()
155
156	if err := json.NewEncoder(file).Encode(cachedTask); err != nil {
157		os.Remove(tempPath)
158		return fmt.Errorf("failed to encode cache data: %w", err)
159	}
160
161	// Atomic rename
162	if err := os.Rename(tempPath, cachePath); err != nil {
163		os.Remove(tempPath)
164		return fmt.Errorf("failed to finalize cache file: %w", err)
165	}
166
167	return nil
168}
169
170// CleanOldCache removes cache entries older than the specified duration
171func (tc *TaskCache) CleanOldCache(maxAge time.Duration) error {
172	entries, err := os.ReadDir(tc.cacheDir)
173	if err != nil {
174		return fmt.Errorf("failed to read cache directory: %w", err)
175	}
176
177	cutoff := time.Now().Add(-maxAge)
178
179	for _, entry := range entries {
180		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
181			cachePath := filepath.Join(tc.cacheDir, entry.Name())
182
183			// Check file modification time
184			info, err := entry.Info()
185			if err != nil {
186				continue
187			}
188
189			if info.ModTime().Before(cutoff) {
190				os.Remove(cachePath)
191			}
192		}
193	}
194
195	return nil
196}
197
198// GetCacheInfo returns information about the cache directory
199func (tc *TaskCache) GetCacheInfo() (string, int, int64, error) {
200	entries, err := os.ReadDir(tc.cacheDir)
201	if err != nil {
202		return tc.cacheDir, 0, 0, fmt.Errorf("failed to read cache directory: %w", err)
203	}
204
205	var totalSize int64
206	fileCount := 0
207
208	for _, entry := range entries {
209		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
210			info, err := entry.Info()
211			if err == nil {
212				totalSize += info.Size()
213				fileCount++
214			}
215		}
216	}
217
218	return tc.cacheDir, fileCount, totalSize, nil
219}
220
221// ClearCache removes all cached task results
222func (tc *TaskCache) ClearCache() error {
223	entries, err := os.ReadDir(tc.cacheDir)
224	if err != nil {
225		return fmt.Errorf("failed to read cache directory: %w", err)
226	}
227
228	for _, entry := range entries {
229		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" {
230			cachePath := filepath.Join(tc.cacheDir, entry.Name())
231			if err := os.Remove(cachePath); err != nil {
232				return fmt.Errorf("failed to remove cache file %s: %w", entry.Name(), err)
233			}
234		}
235	}
236
237	return nil
238}