Harden AI session storage paths

This commit is contained in:
rcourtman 2026-03-28 13:50:55 +00:00
parent f88d89622a
commit d5b4850715
2 changed files with 108 additions and 7 deletions

View file

@ -1,6 +1,8 @@
package chat
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
@ -66,9 +68,30 @@ func NewSessionStore(dataDir string) (*SessionStore, error) {
// sessionPath returns the file path for a session
func (s *SessionStore) sessionPath(id string) string {
return filepath.Join(s.dataDir, hashedSessionStorageName(id)+".json")
}
func (s *SessionStore) legacySessionPath(id string) string {
return filepath.Join(s.dataDir, id+".json")
}
func hashedSessionStorageName(id string) string {
sum := sha256.Sum256([]byte(id))
return hex.EncodeToString(sum[:])
}
func resolveSessionPath(primaryPath, legacyPath string) string {
if _, err := os.Stat(primaryPath); err == nil {
return primaryPath
}
if legacyPath != "" {
if _, err := os.Stat(legacyPath); err == nil {
return legacyPath
}
}
return primaryPath
}
// List returns all sessions, sorted by updated_at descending
func (s *SessionStore) List() ([]Session, error) {
s.mu.RLock()
@ -85,11 +108,16 @@ func (s *SessionStore) List() ([]Session, error) {
continue
}
data, err := s.readSession(strings.TrimSuffix(entry.Name(), ".json"))
file, err := os.ReadFile(filepath.Join(s.dataDir, entry.Name()))
if err != nil {
log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to read session file")
continue
}
var data sessionData
if err := json.Unmarshal(file, &data); err != nil {
log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to parse session file")
continue
}
sessions = append(sessions, Session{
ID: data.ID,
@ -158,12 +186,20 @@ func (s *SessionStore) Delete(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
path := s.sessionPath(id)
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("session not found: %s", id)
primaryPath := s.sessionPath(id)
legacyPath := s.legacySessionPath(id)
var removed bool
for _, path := range []string{primaryPath, legacyPath} {
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("failed to delete session: %w", err)
}
return fmt.Errorf("failed to delete session: %w", err)
removed = true
}
if !removed {
return fmt.Errorf("session not found: %s", id)
}
// Also clean up resolved context, FSM, and knowledge accumulator
@ -239,7 +275,7 @@ func (s *SessionStore) UpdateLastMessage(id string, msg Message) error {
// readSession reads a session from disk (caller must hold lock)
func (s *SessionStore) readSession(id string) (*sessionData, error) {
path := s.sessionPath(id)
path := resolveSessionPath(s.sessionPath(id), s.legacySessionPath(id))
file, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
@ -267,6 +303,10 @@ func (s *SessionStore) writeSession(data sessionData) error {
if err := os.WriteFile(path, file, 0600); err != nil {
return fmt.Errorf("failed to write session: %w", err)
}
legacyPath := s.legacySessionPath(data.ID)
if legacyPath != path {
_ = os.Remove(legacyPath)
}
return nil
}

View file

@ -1,8 +1,10 @@
package chat
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"time"
@ -117,3 +119,62 @@ func TestGenerateTitle(t *testing.T) {
assert.Equal(t, tt.expected, generateTitle(tt.input))
}
}
func TestSessionStore_HashedPathsAndLegacyCompatibility(t *testing.T) {
store, err := NewSessionStore(t.TempDir())
require.NoError(t, err)
session := sessionData{
ID: "legacy-session",
Title: "Legacy Title",
Messages: []Message{{Role: "user", Content: "hello"}},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
path := store.sessionPath(session.ID)
assert.Equal(t, filepath.Join(store.dataDir, hashedSessionStorageName(session.ID)+".json"), path)
assert.NotContains(t, filepath.Base(path), "..")
legacyPath := store.legacySessionPath(session.ID)
raw, err := json.Marshal(session)
require.NoError(t, err)
require.NoError(t, os.WriteFile(legacyPath, raw, 0600))
got, err := store.Get(session.ID)
require.NoError(t, err)
assert.Equal(t, session.ID, got.ID)
assert.Equal(t, session.Title, got.Title)
sessions, err := store.List()
require.NoError(t, err)
require.Len(t, sessions, 1)
assert.Equal(t, session.ID, sessions[0].ID)
}
func TestSessionStore_PathTraversalIDsStayWithinStore(t *testing.T) {
store, err := NewSessionStore(t.TempDir())
require.NoError(t, err)
err = store.writeSession(sessionData{
ID: "..",
Title: "Traversal",
Messages: []Message{},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
require.NoError(t, err)
path := store.sessionPath("..")
rel, err := filepath.Rel(store.dataDir, path)
require.NoError(t, err)
assert.False(t, strings.HasPrefix(rel, ".."))
got, err := store.Get("..")
require.NoError(t, err)
assert.Equal(t, "..", got.ID)
require.NoError(t, store.Delete(".."))
_, err = store.Get("..")
assert.Error(t, err)
}