safing-jess/session.go
2023-09-26 13:26:44 +02:00

561 lines
16 KiB
Go

package jess
import (
"errors"
"fmt"
"hash"
"strings"
"github.com/safing/jess/hashtools"
"github.com/safing/jess/tools"
)
// Session holds session information for operations using the envelope it was initialized with.
type Session struct {
envelope *Envelope
DefaultSymmetricKeySize int
SecurityLevel int
maxSecurityLevel int
toolRequirements *Requirements
// session over the wire
wire *WireSession
// instances
all []tools.ToolLogic
toolsWithState []tools.ToolLogic
kdf tools.ToolLogic
passDerivator tools.ToolLogic
keyExchangers []tools.ToolLogic
keyEncapsulators []tools.ToolLogic
integratedCiphers []tools.ToolLogic
ciphers []tools.ToolLogic
managedMACHashers map[string]*managedHasher
macs []tools.ToolLogic
managedSigningHashers map[string]*managedHasher
signers []tools.ToolLogic
}
type managedHasher struct {
tool *hashtools.HashTool
hash hash.Hash
}
// Sum returns the hash sum of the managed hasher.
func (sh *managedHasher) Sum() ([]byte, error) {
if sh == nil || sh.hash == nil {
return nil, errors.New("managed hasher is broken")
}
return sh.hash.Sum(nil), nil
}
func newSession(e *Envelope) (*Session, error) { //nolint:maintidx
if e.suite == nil {
return nil, errors.New("suite not loaded")
}
// create session
s := &Session{
envelope: e,
toolRequirements: newEmptyRequirements(),
}
// check envelope security level
if e.SecurityLevel > 0 {
err := s.checkSecurityLevel(e.SecurityLevel, func() string {
return fmt.Sprintf(`envelope "%s"`, e.Name)
})
if err != nil {
return nil, err
}
}
// check suite security level
err := s.checkSecurityLevel(e.suite.SecurityLevel, func() string {
return fmt.Sprintf(`suite "%s"`, e.suite.ID)
})
if err != nil {
return nil, err
}
// prepare variables
var (
keySourceAvailable bool
totalSignetsSeen int
requireSecurityLevel bool
requireDefaultKeySize bool
)
// tool init loop: start
for i, toolID := range s.envelope.suite.Tools {
// ====================================
// tool init loop: check for duplicates
// ====================================
for j, dupeToolID := range s.envelope.suite.Tools {
if i != j && toolID == dupeToolID {
return nil, fmt.Errorf("cannot use tool %s twice, each tool may be only specified once", toolID)
}
}
// ===================================
// tool init loop: parse, prep and get
// ===================================
var (
hashTool *hashtools.HashTool
hashSumFn func() ([]byte, error)
)
// parse ID for args
var arg string
if strings.Contains(toolID, "(") {
splitted := strings.Split(toolID, "(")
toolID = splitted[0]
arg = strings.Trim(splitted[1], "()")
}
// get tool
tool, err := tools.Get(toolID)
if err != nil {
return nil, fmt.Errorf("the specified tool %s could not be found", toolID)
}
// create logic instance and add to logic and state lists
logic := tool.Factory()
s.all = append(s.all, logic)
if tool.Info.HasOption(tools.OptionHasState) {
s.toolsWithState = append(s.toolsWithState, logic)
}
// ===========================================================
// tool init loop: assign tools to queues and add requirements
// ===========================================================
switch tool.Info.Purpose {
case tools.PurposeKeyDerivation:
if s.kdf != nil {
return nil, fmt.Errorf("cannot use %s, you may only specify one key derivation tool and %s was already specified", tool.Info.Name, s.kdf.Info().Name)
}
s.kdf = logic
case tools.PurposePassDerivation:
if s.passDerivator != nil {
return nil, fmt.Errorf("cannot use %s, you may only specify one password derivation tool and %s was already specified", tool.Info.Name, s.passDerivator.Info().Name)
}
s.passDerivator = logic
s.toolRequirements.Add(SenderAuthentication)
s.toolRequirements.Add(RecipientAuthentication)
case tools.PurposeKeyExchange:
s.keyExchangers = append(s.keyExchangers, logic)
s.toolRequirements.Add(RecipientAuthentication)
case tools.PurposeKeyEncapsulation:
s.keyEncapsulators = append(s.keyEncapsulators, logic)
s.toolRequirements.Add(RecipientAuthentication)
case tools.PurposeSigning:
s.signers = append(s.signers, logic)
s.toolRequirements.Add(Integrity)
s.toolRequirements.Add(SenderAuthentication)
case tools.PurposeIntegratedCipher:
s.integratedCiphers = append(s.integratedCiphers, logic)
s.toolRequirements.Add(Confidentiality)
s.toolRequirements.Add(Integrity)
case tools.PurposeCipher:
s.ciphers = append(s.ciphers, logic)
s.toolRequirements.Add(Confidentiality)
case tools.PurposeMAC:
s.macs = append(s.macs, logic)
s.toolRequirements.Add(Integrity)
}
// ============================================
// tool init loop: process options, get hashers
// ============================================
for _, option := range tool.Info.Options {
switch option {
case tools.OptionStreaming:
// TODO: Implementation pending.
case tools.OptionNeedsManagedHasher:
// get managed hasher list
var managedHashers map[string]*managedHasher
switch tool.Info.Purpose {
case tools.PurposeMAC:
if s.managedMACHashers == nil {
s.managedMACHashers = make(map[string]*managedHasher)
}
managedHashers = s.managedMACHashers
case tools.PurposeSigning:
if s.managedSigningHashers == nil {
s.managedSigningHashers = make(map[string]*managedHasher)
}
managedHashers = s.managedSigningHashers
default:
return nil, fmt.Errorf("only MAC and Signing tools may use managed hashers")
}
// get or assign a new managed hasher
mngdHasher, ok := managedHashers[arg]
if !ok {
// get hashtool
ht, err := hashtools.Get(arg)
if err != nil {
return nil, fmt.Errorf("the specified hashtool for %s(%s) could not be found", toolID, arg)
}
// save to managed hashers
mngdHasher = &managedHasher{
tool: ht,
hash: ht.New(),
}
managedHashers[arg] = mngdHasher
}
hashTool = mngdHasher.tool
hashSumFn = mngdHasher.Sum
case tools.OptionNeedsDedicatedHasher:
hashTool, err = hashtools.Get(arg)
if err != nil {
return nil, fmt.Errorf("the specified hashtool for %s(%s) could not be found", toolID, arg)
}
case tools.OptionNeedsSecurityLevel:
requireSecurityLevel = true
case tools.OptionNeedsDefaultKeySize:
requireDefaultKeySize = true
}
}
// ===============================
// tool init loop: initialize tool
// ===============================
// init tool
logic.Init(
tool,
&Helper{
session: s,
info: tool.Info,
},
hashTool,
hashSumFn,
)
// ==============================================
// tool init loop: calc and check security levels
// ==============================================
err = s.calcAndCheckSecurityLevel(logic, nil)
if err != nil {
return nil, err
}
// ==========================================
// tool init loop: calculate default key size
// ==========================================
// find biggest key size for default
if tool.Info.KeySize > s.DefaultSymmetricKeySize {
s.DefaultSymmetricKeySize = tool.Info.KeySize
}
} // tool init loop: end
// =======================================================
// calc and check signet security levels, default key size
// =======================================================
for _, tool := range s.all {
var err error
var seen int
// calc and check signet security levels
switch tool.Info().Purpose {
case tools.PurposePassDerivation:
//nolint:scopelint // function is executed immediately within loop
err = e.LoopSecrets(SignetSchemePassword, func(signet *Signet) error {
seen++
return s.calcAndCheckSecurityLevel(tool, signet)
})
keySourceAvailable = true
case tools.PurposeKeyExchange,
tools.PurposeKeyEncapsulation:
//nolint:scopelint // function is executed immediately within loop
err = e.LoopRecipients(tool.Info().Name, func(signet *Signet) error {
seen++
return s.calcAndCheckSecurityLevel(tool, signet)
})
keySourceAvailable = true
case tools.PurposeSigning:
//nolint:scopelint // function is executed immediately within loop
err = e.LoopSenders(tool.Info().Name, func(signet *Signet) error {
seen++
return s.calcAndCheckSecurityLevel(tool, signet)
})
keySourceAvailable = true
default:
continue
}
// check error
if err != nil {
return nil, err
}
// check if anything is here
if seen == 0 {
return nil, fmt.Errorf("tool %s requires at least one signet", tool.Info().Name)
}
totalSignetsSeen += seen
}
// key signets
err = e.LoopSecrets(SignetSchemeKey, func(signet *Signet) error {
s.toolRequirements.Add(SenderAuthentication)
s.toolRequirements.Add(RecipientAuthentication)
totalSignetsSeen++
keySourceAvailable = true
return s.calcAndCheckSecurityLevel(nil, signet)
})
if err != nil {
return nil, err
}
// ======================================================
// check security level and default key size requirements
// ======================================================
// apply manual security level
if minimumSecurityLevel > 0 && minimumSecurityLevel > s.SecurityLevel {
s.SecurityLevel = minimumSecurityLevel
}
// apply manual key size
if minimumSymmetricKeySize > 0 && minimumSymmetricKeySize > s.DefaultSymmetricKeySize {
s.DefaultSymmetricKeySize = minimumSymmetricKeySize
}
// check security level requirement
if requireSecurityLevel && s.SecurityLevel == 0 {
return nil, fmt.Errorf("this toolset requires the security level to be set manually")
}
// check default key size requirement
if requireDefaultKeySize && s.DefaultSymmetricKeySize == 0 {
return nil, fmt.Errorf("this toolset requires the default key size to be set manually")
}
// ============
// final checks
// ============
// check requirements
if s.toolRequirements.Empty() {
return nil, errors.New("envelope excludes all security requirements, no meaningful operation possible")
}
err = s.toolRequirements.CheckComplianceTo(s.envelope.suite.Provides)
if err != nil {
return nil, err
}
// check if we have recipient auth without confidentiality
if s.toolRequirements.Has(RecipientAuthentication) &&
!s.toolRequirements.Has(Confidentiality) {
return nil, errors.New("having recipient authentication without confidentiality does not make sense")
}
// check if we have confidentiality without integrity
if s.toolRequirements.Has(Confidentiality) &&
!s.toolRequirements.Has(Integrity) {
return nil, errors.New("having confidentiality without integrity does not make sense")
}
// check if we are missing a kdf, but need one
if s.kdf == nil && len(s.signers) != len(s.envelope.suite.Tools) {
return nil, errors.New("missing a key derivation tool")
}
// check if have a kdf, even if we don't need one
if len(s.integratedCiphers) == 0 &&
len(s.ciphers) == 0 &&
len(s.macs) == 0 &&
s.kdf != nil {
return nil, errors.New("key derivation tool specified, but not needed")
}
// check if we have a key source
if !keySourceAvailable &&
(s.toolRequirements.Has(Integrity) || s.toolRequirements.Has(Confidentiality)) {
return nil, errors.New("missing key source, please add a tool that provides a key or add a key signet directly")
}
// check if there are unused signets
if len(s.envelope.Secrets)+
len(s.envelope.Senders)+
len(s.envelope.Recipients) > totalSignetsSeen {
return nil, fmt.Errorf("detected signet or recipient in envelope that is not used by any tool")
}
// check session security level
// while this should never result in an error (because every part was already checked separately) this is used as a precaution to catch errors in future code changes
err = s.checkSecurityLevel(s.SecurityLevel, func() string {
return "current session"
})
if err != nil {
return nil, err
}
return s, nil
}
//nolint:gocognit
func (s *Session) calcAndCheckSecurityLevel(logic tools.ToolLogic, signet *Signet) error {
// get signet scheme
signetScheme := ""
if signet != nil {
signetScheme = signet.Scheme
}
var err error
var calculatedSecurityLevel int
switch {
case signetScheme == SignetSchemeKey:
calculatedSecurityLevel = len(signet.Key) * 8
case signetScheme == SignetSchemePassword && signet != nil:
// only check if present
// existence check is done when opening/closing
if len(signet.Key) > 0 {
switch logic.Info().Name {
case "SCRYPT-20":
// TODO: integrate this into the tool interface
calculatedSecurityLevel = CalculatePasswordSecurityLevel(string(signet.Key), 1<<20)
case "PBKDF2-SHA2-256":
// TODO: integrate this into the tool interface
calculatedSecurityLevel = CalculatePasswordSecurityLevel(string(signet.Key), 20000)
default:
calculatedSecurityLevel = CalculatePasswordSecurityLevel(string(signet.Key), 1)
}
if calculatedSecurityLevel < 0 {
return fmt.Errorf(`supplied password signet "%s" is exceptionally weak and should not be used`, signet.ID)
}
}
default:
// get tool security level
if signet == nil {
// nil interface hackery for inherited SecurityLevel() functions
calculatedSecurityLevel, err = logic.SecurityLevel(nil)
} else {
calculatedSecurityLevel, err = logic.SecurityLevel(signet)
}
if err != nil {
return err
}
}
if calculatedSecurityLevel == 0 {
// not applicable
return nil
}
if calculatedSecurityLevel < 0 {
// broken!
if signet != nil {
return fmt.Errorf(`supplied %s signet "%s" is considered broken and should not be used anymore`, signet.Scheme, signet.ID)
}
return fmt.Errorf(`tool %s is considered broken and should not be used anymore`, logic.Info().Name)
}
if signet != nil {
// signet based security level checks
err = s.checkSecurityLevel(calculatedSecurityLevel, func() string {
return fmt.Sprintf(`supplied %s signet "%s"`, signet.Scheme, signet.ID)
})
} else {
// tool based securty level checks
err = s.checkSecurityLevel(calculatedSecurityLevel, func() string {
return "tool %s" + logic.Info().Name
})
}
if err != nil {
return err
}
// adapt security level of session
// lower session security level
if s.SecurityLevel == 0 || calculatedSecurityLevel < s.SecurityLevel {
s.SecurityLevel = calculatedSecurityLevel
}
// raise session max security level
if calculatedSecurityLevel > s.maxSecurityLevel {
s.maxSecurityLevel = calculatedSecurityLevel
}
return nil
}
func (s *Session) checkSecurityLevel(levelToCheck int, subject func() string) error {
switch {
case minimumSecurityLevel > 0:
// check against minimumSecurityLevel
// (overrides other checks)
if levelToCheck < minimumSecurityLevel {
return fmt.Errorf(
`%s with a security level of %d is weaker than the desired security level of %d`,
subject(),
levelToCheck,
minimumSecurityLevel,
)
}
case s.envelope.SecurityLevel > 0:
// check against envelope's minimum security level
if levelToCheck < s.envelope.SecurityLevel {
return fmt.Errorf(
`%s with a security level of %d is weaker than the envelope's minimum security level of %d`,
subject(),
levelToCheck,
s.envelope.SecurityLevel,
)
}
case levelToCheck < defaultSecurityLevel:
// check against default security level as fallback
return fmt.Errorf(
`%s with a security level of %d is weaker than the default minimum security level of %d`,
subject(),
levelToCheck,
defaultSecurityLevel,
)
}
return nil
}
// NonceSize returns the nonce size to use for new letters.
func (s *Session) NonceSize() int {
size := s.maxSecurityLevel / 32
if size < 4 {
size = 4
}
return size
}