/ internal / memory / bundle.go
bundle.go
  1  package memory
  2  
  3  import (
  4  	"context"
  5  	"crypto/sha256"
  6  	"encoding/hex"
  7  	"encoding/json"
  8  	"fmt"
  9  	"io"
 10  	"net/http"
 11  	"net/url"
 12  	"os"
 13  	"path/filepath"
 14  	"sort"
 15  	"strings"
 16  	"syscall"
 17  	"time"
 18  )
 19  
// ManifestFile describes a single file entry in a bundle Manifest.
//
//   - Path is the file's location relative to the bundle directory. It is
//     supplied by the server and is checked by validateManifestPath before
//     anything is downloaded to it.
//   - Size is the declared size of the file; the code visible in this file
//     does not read it. NOTE(review): presumably bytes — confirm against the
//     server contract.
//   - Sha256 is the expected hex-encoded SHA-256 digest; downloadFile compares
//     it against the digest of the bytes it actually received.
type ManifestFile struct {
	Path   string `json:"path"`
	Size   int64  `json:"size"`
	Sha256 string `json:"sha256"`
}
 25  
// Manifest is the JSON document decoded by fetchManifest from the bundle
// manifest endpoint.
//
//   - BundleTs identifies the bundle. It is used as the directory name under
//     staging/ and bundles/, and tick compares it as a plain string against
//     the basename of the `current` symlink target.
//   - BundleVersion must satisfy versionInRange or the bundle is refused.
//   - SizeBytes and IntegritySha256 are decoded but not read by the code
//     visible in this file. NOTE(review): presumably whole-bundle size and
//     digest — confirm against the server contract.
//   - Files lists the files that make up the bundle.
type Manifest struct {
	BundleTs        string         `json:"bundle_ts"`
	BundleVersion   string         `json:"bundle_version"`
	SizeBytes       int64          `json:"size_bytes"`
	IntegritySha256 string         `json:"integrity_sha256"`
	Files           []ManifestFile `json:"files"`
}
 33  
// Puller drives the periodic bundle download cycle in cloud mode.
//   - cfg supplies BundleRoot, Endpoint, APIKey and SocketPath.
//   - sidecar may be nil; reloadSidecar is then a no-op.
//   - audit may be nil; events are dropped silently when so.
//   - httpc is the HTTP client used for the manifest and file downloads.
type Puller struct {
	cfg     Config
	sidecar *Sidecar
	audit   AuditLogger
	httpc   *http.Client
}
 44  
 45  func NewPuller(cfg Config, sidecar *Sidecar, audit AuditLogger) *Puller {
 46  	return &Puller{
 47  		cfg:     cfg,
 48  		sidecar: sidecar,
 49  		audit:   audit,
 50  		httpc:   &http.Client{Timeout: 60 * time.Second},
 51  	}
 52  }
 53  
 54  // versionInRange enforces [0.4.0, 0.5.0). Hand-rolled (no semver dep) — the
 55  // constraint is fixed and trivially encodable as integer triplets.
 56  func versionInRange(v string) bool {
 57  	parts := strings.SplitN(v, ".", 3)
 58  	if len(parts) != 3 {
 59  		return false
 60  	}
 61  	var maj, min, pat int
 62  	if _, err := fmt.Sscanf(v, "%d.%d.%d", &maj, &min, &pat); err != nil {
 63  		return false
 64  	}
 65  	if maj != 0 {
 66  		return false
 67  	}
 68  	if min != 4 {
 69  		return false
 70  	}
 71  	return pat >= 0
 72  }
 73  
 74  // tick is one iteration of the bundle pull cycle. Steps 1-4 implemented here:
 75  //  1. flock the bundle root (silent skip on contention)
 76  //  2. tenant fingerprint check (wipe local bundles on switch)
 77  //  3. fetch manifest + version range gate
 78  //  4. compare bundle_ts against the current symlink target (no-op if same)
 79  //
 80  // Steps 5-8 perform sandboxed download, SHA256 verification, atomic install,
 81  // reload, and retention.
 82  func (p *Puller) tick(ctx context.Context) error {
 83  	// Step 1: flock
 84  	if err := os.MkdirAll(p.cfg.BundleRoot, 0o700); err != nil {
 85  		return err
 86  	}
 87  	lockPath := filepath.Join(p.cfg.BundleRoot, "bundle.lock")
 88  	f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
 89  	if err != nil {
 90  		return err
 91  	}
 92  	defer f.Close()
 93  	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
 94  		// Contention: another caller is mid-pull; we'll get the next tick.
 95  		return nil
 96  	}
 97  	defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
 98  
 99  	// Step 2: tenant check (cloud-only — caller ensures provider==cloud
100  	// before invoking tick).
101  	switched, err := DetectTenantSwitch(p.cfg.BundleRoot, p.cfg.APIKey)
102  	if err != nil {
103  		return err
104  	}
105  	if switched {
106  		if err := os.RemoveAll(filepath.Join(p.cfg.BundleRoot, "bundles")); err != nil {
107  			return err
108  		}
109  		_ = os.Remove(filepath.Join(p.cfg.BundleRoot, "current"))
110  		if err := WriteFingerprint(p.cfg.BundleRoot, p.cfg.APIKey); err != nil {
111  			return err
112  		}
113  		if p.audit != nil {
114  			p.audit.Log("memory_tenant_switch", map[string]any{})
115  		}
116  	}
117  
118  	// Step 3: fetch manifest + version gate
119  	mf, err := p.fetchManifest(ctx)
120  	if err != nil {
121  		return err
122  	}
123  	if !versionInRange(mf.BundleVersion) {
124  		return fmt.Errorf("bundle_version %q outside [0.4.0, 0.5.0)", mf.BundleVersion)
125  	}
126  
127  	// Step 4: compare ts
128  	if cur := p.currentTs(); cur != "" && mf.BundleTs <= cur {
129  		return nil
130  	}
131  
132  	return p.installBundle(ctx, mf)
133  }
134  
135  func (p *Puller) fetchManifest(ctx context.Context) (*Manifest, error) {
136  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.cfg.Endpoint+"/api/v1/memory/bundle/manifest", nil)
137  	if err != nil {
138  		return nil, err
139  	}
140  	if p.cfg.APIKey != "" {
141  		req.Header.Set("X-API-Key", p.cfg.APIKey)
142  	}
143  	resp, err := p.httpc.Do(req)
144  	if err != nil {
145  		return nil, err
146  	}
147  	defer resp.Body.Close()
148  	if resp.StatusCode != 200 {
149  		body, _ := io.ReadAll(resp.Body)
150  		return nil, fmt.Errorf("manifest status %d: %s", resp.StatusCode, body)
151  	}
152  	var mf Manifest
153  	if err := json.NewDecoder(resp.Body).Decode(&mf); err != nil {
154  		return nil, err
155  	}
156  	return &mf, nil
157  }
158  
159  // currentTs reads the symlink target for <bundleRoot>/current and returns its
160  // basename (the bundle ts). Empty string if the symlink is absent or unreadable.
161  func (p *Puller) currentTs() string {
162  	target, err := os.Readlink(filepath.Join(p.cfg.BundleRoot, "current"))
163  	if err != nil {
164  		return ""
165  	}
166  	return filepath.Base(target)
167  }
168  
169  func (p *Puller) installBundle(ctx context.Context, mf *Manifest) error {
170  	staging := filepath.Join(p.cfg.BundleRoot, "staging", mf.BundleTs)
171  	if err := os.MkdirAll(staging, 0o700); err != nil {
172  		return err
173  	}
174  	cleanup := func() { _ = os.RemoveAll(staging) }
175  
176  	for _, f := range mf.Files {
177  		if err := validateManifestPath(f.Path, staging); err != nil {
178  			sample := f.Path
179  			if len(sample) > 64 {
180  				sample = sample[:64]
181  			}
182  			if p.audit != nil {
183  				p.audit.Log("memory_bundle_unsafe_path", map[string]any{
184  					"path_sample": sample,
185  					"reason":      err.Error(),
186  				})
187  			}
188  			cleanup()
189  			return fmt.Errorf("unsafe manifest path %q: %w", sample, err)
190  		}
191  		if err := p.downloadFile(ctx, mf.BundleTs, f, staging); err != nil {
192  			cleanup()
193  			return err
194  		}
195  	}
196  
197  	if err := p.atomicInstall(staging, mf.BundleTs); err != nil {
198  		cleanup()
199  		return err
200  	}
201  	if err := p.reloadSidecar(ctx); err != nil && p.audit != nil {
202  		p.audit.Log("memory_reload_failed", map[string]any{"reason": err.Error()})
203  	}
204  	p.retain(3)
205  	return nil
206  }
207  
// validateManifestPath enforces the path-sandboxing rules from spec §4.2 on a
// manifest-supplied relative path. It returns a non-nil error when rel
//   - is empty,
//   - contains a null byte,
//   - is absolute,
//   - climbs out through parent traversal once cleaned ("..", "../x",
//     "a/../../x"),
//   - cleans to the staging directory itself (".", "a/.."), which cannot name
//     a file, or
//   - otherwise lands outside stagingDir after Clean+Join.
//
// A name that merely starts with two dots, such as "dir/..hidden", is an
// ordinary file name and is accepted. This MUST run before any network I/O so
// a malicious manifest cannot trigger a download to an unauthorized location.
func validateManifestPath(rel string, stagingDir string) error {
	if rel == "" {
		return fmt.Errorf("empty path")
	}
	if strings.ContainsRune(rel, 0) {
		return fmt.Errorf("null byte in path")
	}
	if strings.HasPrefix(rel, "/") || filepath.IsAbs(rel) {
		return fmt.Errorf("absolute path")
	}
	sep := string(os.PathSeparator)
	cleaned := filepath.Clean(rel)
	// Clean collapses inner "x/.." pairs, so any ".." component that survives
	// sits at the very front of the path. Checking whole components there
	// avoids rejecting names that only begin with "..".
	if cleaned == ".." || strings.HasPrefix(cleaned, ".."+sep) {
		return fmt.Errorf("contains parent traversal")
	}
	if cleaned == "." {
		return fmt.Errorf("resolves to staging dir itself")
	}
	// Defense in depth: the joined path must sit strictly below stagingDir.
	root := filepath.Clean(stagingDir)
	if !strings.HasPrefix(filepath.Join(root, cleaned), root+sep) {
		return fmt.Errorf("escapes staging dir")
	}
	return nil
}
235  
236  // downloadFile streams one manifest file into the (already-sandboxed) staging
237  // path while computing its SHA256. Mismatch → return error and let the caller
238  // clean staging.
239  func (p *Puller) downloadFile(ctx context.Context, ts string, f ManifestFile, staging string) error {
240  	target := filepath.Join(staging, filepath.Clean(f.Path))
241  	if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
242  		return err
243  	}
244  	escapedTS := url.PathEscape(ts)
245  	escapedPath := escapedManifestPath(f.Path)
246  	fullURL := strings.TrimSuffix(p.cfg.Endpoint, "/") + "/api/v1/memory/bundle/" + escapedTS + "/" + escapedPath
247  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
248  	if err != nil {
249  		return err
250  	}
251  	if p.cfg.APIKey != "" {
252  		req.Header.Set("X-API-Key", p.cfg.APIKey)
253  	}
254  	resp, err := p.httpc.Do(req)
255  	if err != nil {
256  		return err
257  	}
258  	defer resp.Body.Close()
259  	if resp.StatusCode != 200 {
260  		return fmt.Errorf("file %s status %d", f.Path, resp.StatusCode)
261  	}
262  	out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
263  	if err != nil {
264  		return err
265  	}
266  	defer out.Close()
267  	h := sha256.New()
268  	if _, err := io.Copy(io.MultiWriter(out, h), resp.Body); err != nil {
269  		return err
270  	}
271  	got := hex.EncodeToString(h.Sum(nil))
272  	if got != f.Sha256 {
273  		if p.audit != nil {
274  			p.audit.Log("memory_bundle_install_failed", map[string]any{
275  				"reason":      "sha256_mismatch",
276  				"path_sample": f.Path,
277  			})
278  		}
279  		return fmt.Errorf("sha256 mismatch on %s: got %s want %s", f.Path, got, f.Sha256)
280  	}
281  	return nil
282  }
283  
// escapedManifestPath turns a manifest path into a URL path fragment: the path
// is converted to forward slashes, each segment is percent-escaped on its own
// with url.PathEscape, and the segments are joined back with "/". The
// separators themselves are therefore never escaped.
func escapedManifestPath(raw string) string {
	segments := strings.Split(filepath.ToSlash(raw), "/")
	for i := range segments {
		segments[i] = url.PathEscape(segments[i])
	}
	return strings.Join(segments, "/")
}
292  
293  // atomicInstall renames the staging dir into bundles/<ts> and atomically
294  // swaps the `current` symlink. Both rename + symlink-swap are POSIX-atomic
295  // on the same filesystem.
296  func (p *Puller) atomicInstall(stagingDir, ts string) error {
297  	bundlesDir := filepath.Join(p.cfg.BundleRoot, "bundles")
298  	if err := os.MkdirAll(bundlesDir, 0o700); err != nil {
299  		return err
300  	}
301  	finalDir := filepath.Join(bundlesDir, ts)
302  	if err := os.Rename(stagingDir, finalDir); err != nil {
303  		return fmt.Errorf("rename staging→bundle: %w", err)
304  	}
305  	tmpLink := filepath.Join(p.cfg.BundleRoot, "current.tmp")
306  	_ = os.Remove(tmpLink)
307  	if err := os.Symlink(finalDir, tmpLink); err != nil {
308  		return fmt.Errorf("symlink current.tmp: %w", err)
309  	}
310  	if err := os.Rename(tmpLink, filepath.Join(p.cfg.BundleRoot, "current")); err != nil {
311  		return fmt.Errorf("swap current symlink: %w", err)
312  	}
313  	return nil
314  }
315  
316  // reloadSidecar pings the sidecar's /bundle/reload endpoint via UDS so it
317  // picks up the new symlink target immediately. On 409 (reload_in_progress)
318  // retries once after 1s. Other failures are non-fatal — the sidecar's own
319  // poller will pick up the new bundle eventually.
320  func (p *Puller) reloadSidecar(ctx context.Context) error {
321  	if p.sidecar == nil {
322  		return nil
323  	}
324  	c := NewClient(p.cfg.SocketPath, 5*time.Second)
325  	_, err := c.Reload(ctx)
326  	if err != nil && strings.Contains(err.Error(), "reload_in_progress") {
327  		select {
328  		case <-ctx.Done():
329  			return ctx.Err()
330  		case <-time.After(1 * time.Second):
331  		}
332  		_, err = c.Reload(ctx)
333  	}
334  	return err
335  }
336  
337  // retain keeps the newest `keep` bundle dirs by ts plus the current symlink
338  // target (defensive). Best-effort — failures logged but not fatal.
339  func (p *Puller) retain(keep int) {
340  	bundlesDir := filepath.Join(p.cfg.BundleRoot, "bundles")
341  	entries, err := os.ReadDir(bundlesDir)
342  	if err != nil {
343  		return
344  	}
345  	var dirs []string
346  	for _, e := range entries {
347  		if e.IsDir() {
348  			dirs = append(dirs, e.Name())
349  		}
350  	}
351  	sort.Sort(sort.Reverse(sort.StringSlice(dirs)))
352  	if len(dirs) <= keep {
353  		return
354  	}
355  	currentTarget := p.currentTs()
356  	keepSet := map[string]bool{}
357  	for i, d := range dirs {
358  		if i < keep {
359  			keepSet[d] = true
360  		}
361  	}
362  	if currentTarget != "" {
363  		keepSet[currentTarget] = true
364  	}
365  	for _, d := range dirs {
366  		if !keepSet[d] {
367  			_ = os.RemoveAll(filepath.Join(bundlesDir, d))
368  		}
369  	}
370  }