safing-jess/core-wire.go
2021-10-02 23:00:01 +02:00

466 lines
11 KiB
Go

package jess
import (
"errors"
"fmt"
)
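// WireSession state used by the handshake functions in this file (the struct
// itself is declared elsewhere; names below are as the code here uses them):
//
//	msgNo            uint64
//	lastReKeyAtMsgNo uint64
//
//	sendKeyCarryover []byte
//	recvKeyCarryover []byte
//
//	handshakeState   uint8
//	eKXSignets       key exchange pairs (local signet, peer signet, tool)
//	eKESignets       key encapsulation pairs (signet, seal, tool)
//	newKeyMaterial   [][]byte
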
// sendHandshakeAndInitKDF advances the sending side of the wire handshake for
// letter and initializes the session KDF for it.
//
// Per handshake state:
//   - wireStateInit (client): sets up the initial closing key material, then
//     continues as wireStateIdle.
//   - wireStateIdle: on the first message, or on the client when reKeyNeeded
//     reports true, adds fresh local key exchange and key encapsulation public
//     keys to letter.Keys and moves to wireStateAwaitKey.
//   - wireStateSendKey (server): adds its local key exchange public keys to
//     letter.Keys, derives the shared keys, encapsulates fresh random keys with
//     the key encapsulation signets (adding them to letter.Keys), uses all of
//     this as key material for this message, keeps a copy in w.newKeyMaterial
//     and moves to wireStatsAwaitApply.
//   - wireStateSendApply (client): uses w.newKeyMaterial again, sets
//     letter.ApplyKeys and burns the ephemeral keys once the KDF is set up.
//
// Afterwards the send carryover key is mixed in (except on the first message,
// where only a DefaultSymmetricKeySize buffer for it is allocated), the KDF is
// initialized with letter.Nonce and the next send carryover key is derived
// from it. On the first message that key is also copied to the receive
// carryover. On success w.msgNo is incremented.
func (w *WireSession) sendHandshakeAndInitKDF(letter *Letter) error {
	var (
		err         error
		keyMaterial [][]byte
		burn        bool
	)

	// Process handshake.
	switch w.handshakeState {
	case wireStateInit: // client
		keyMaterial, err = w.session.setupClosingKeyMaterial(letter)
		if err != nil {
			return fmt.Errorf("failed to setup initial sending handshake key material: %w", err)
		}
		fallthrough

	case wireStateIdle: // client and server
		// Start a handshake on the very first message, or (client only) when a re-key is due.
		if w.msgNo == 0 || (!w.server && w.reKeyNeeded()) {
			err = w.generateLocalKeyExchangeSignets(letter)
			if err != nil {
				return fmt.Errorf("failed to generate local key exchange signets for initiating handshake: %w", err)
			}
			err = w.generateLocalKeyEncapsulationSignets(letter)
			if err != nil {
				return fmt.Errorf("failed to generate local key encapsulation signets for initiating handshake: %w", err)
			}
			w.handshakeState = wireStateAwaitKey
		}

	case wireStateSendKey: // server
		err = w.generateLocalKeyExchangeSignets(letter)
		if err != nil {
			return fmt.Errorf("failed to generate local key exchange signets for completing handshake: %w", err)
		}
		keyMaterial, err = w.makeSharedKeys(keyMaterial)
		if err != nil {
			return fmt.Errorf("failed to create shared keys for completing handshake: %w", err)
		}
		err = w.generateLocalKeyEncapsulationSignets(letter)
		if err != nil {
			return fmt.Errorf("failed to generate local key encapsulation signets for completing handshake: %w", err)
		}
		keyMaterial, err = w.makeAndEncapsulateNewKeys(letter, keyMaterial)
		if err != nil {
			return fmt.Errorf("failed to encapsulate keys for completing handshake: %w", err)
		}
		// Keep a private copy: the same material is applied again when the
		// client confirms with ApplyKeys.
		w.newKeyMaterial = copyKeyMaterial(keyMaterial)
		w.handshakeState = wireStatsAwaitApply

	case wireStateSendApply: // client
		keyMaterial = append(keyMaterial, w.newKeyMaterial...)
		letter.ApplyKeys = true
		burn = true
		// NOTE(review): handshakeState is not moved back to wireStateIdle here
		// (nor anywhere else in this file), so unless it is reset elsewhere the
		// client keeps setting ApplyKeys on every later message and never
		// re-enters the re-key check above. Confirm this is intended.
	}

	// Add the carryover key. The first message has none yet; only allocate
	// the buffer the first derived carryover key is written into.
	if w.msgNo == 0 {
		if w.session.DefaultSymmetricKeySize == 0 {
			return errors.New("missing default key size")
		}
		w.sendKeyCarryover = make([]byte, w.session.DefaultSymmetricKeySize)
	} else {
		keyMaterial = append(keyMaterial, w.sendKeyCarryover)
	}

	// Init KDF.
	err = w.session.kdf.InitKeyDerivation(letter.Nonce, keyMaterial...)
	if err != nil {
		return fmt.Errorf("failed to init %s kdf: %w", w.session.kdf.Info().Name, err)
	}

	// Derive the new carryover key.
	err = w.session.kdf.DeriveKeyWriteTo(w.sendKeyCarryover)
	if err != nil {
		return fmt.Errorf("failed to iterate session key with %s: %w", w.session.kdf.Info().Name, err)
	}
	if w.msgNo == 0 {
		// Both directions start from the same initial carryover key.
		w.recvKeyCarryover = make([]byte, len(w.sendKeyCarryover))
		copy(w.recvKeyCarryover, w.sendKeyCarryover)
	}

	// Increase msg counter.
	w.msgNo++

	// Burn only after the KDF has consumed the new key material.
	if burn {
		return w.burnEphemeralKeys()
	}
	return nil
}
// recvHandshakeAndInitKDF advances the receiving side of the wire handshake
// for a received letter and initializes the session KDF for it.
//
// Per handshake state:
//   - wireStateInit (server): sets up the initial opening key material, then
//     continues as wireStateIdle.
//   - wireStateIdle (server): if the letter carries keys, assigns the first
//     len(w.eKXSignets) keys as peer signets of the key exchange pairs and
//     the following ID-less keys as signets of the key encapsulation pairs,
//     then moves to wireStateSendKey.
//   - wireStateAwaitKey (client): if the letter carries keys, assigns the
//     peer key exchange keys and the encapsulated key seals, derives the
//     shared keys and unwraps the encapsulated keys, uses all of this as key
//     material for this message, keeps a copy in w.newKeyMaterial and moves
//     to wireStateSendApply.
//   - wireStatsAwaitApply (server): if letter.ApplyKeys is set, uses
//     w.newKeyMaterial again and burns the ephemeral keys once the KDF is set up.
//
// letter.Keys is controlled by the peer, so its entries and length are
// validated before they are indexed: a malformed letter yields an error
// instead of a panic and leaves the kx/ke pairs untouched.
//
// Afterwards the receive carryover key is mixed in (except on the first
// message), the KDF is initialized with letter.Nonce and the next receive
// carryover key is derived. On the first message that key is also copied to
// the send carryover. On success w.msgNo is incremented.
//
//nolint:gocognit
func (w *WireSession) recvHandshakeAndInitKDF(letter *Letter) error {
	var (
		err         error
		keyMaterial [][]byte
		burn        bool
	)

	// Process handshake.
	switch w.handshakeState {
	case wireStateInit: // server
		keyMaterial, err = w.session.setupOpeningKeyMaterial(letter)
		if err != nil {
			return fmt.Errorf("failed to setup initial receiving handshake key material: %w", err)
		}
		fallthrough

	case wireStateIdle: // server
		if len(letter.Keys) > 0 {
			// Reject nil entries before dereferencing them.
			for _, seal := range letter.Keys {
				if seal == nil {
					return errors.New("failed to setup receiving handshake: letter contains an empty key")
				}
			}

			// Check if there are the right amount of keys.
			if len(w.session.keyEncapsulators) == 0 {
				// TODO:
				// initial wire handshake is special:
				// key encapsulators send two seals in the initial handshake messages
				// one of them is added to the recipients
				// the other is a new ephemeral key
				if len(letter.Keys) != len(w.eKXSignets)+len(w.eKESignets) {
					return errors.New("failed to setup initial receiving handshake: incorrect amount of keys in letter")
				}
			}

			// Locate the keys of all pairs before assigning any of them.
			// The first len(w.eKXSignets) keys belong to the kx pairs.
			kxCount := len(w.eKXSignets)
			if len(letter.Keys) < kxCount {
				return errors.New("failed to setup receiving handshake: missing key exchange keys in letter")
			}
			// The ke keys are the following keys without an ID; keys with an
			// ID are skipped.
			keKeyIndexes := make([]int, 0, len(w.eKESignets))
			for i := kxCount; i < len(letter.Keys) && len(keKeyIndexes) < len(w.eKESignets); i++ {
				if letter.Keys[i].ID == "" {
					keKeyIndexes = append(keKeyIndexes, i)
				}
			}
			if len(keKeyIndexes) < len(w.eKESignets) {
				return errors.New("failed to setup receiving handshake: missing key encapsulation keys in letter")
			}

			// Assign keys to kx/ke pairs.
			for i, kxPair := range w.eKXSignets {
				kxPair.peer = &Signet{
					Version: letter.Version,
					Key:     letter.Keys[i].Value,
					Public:  true,
					tool:    kxPair.tool.Definition(),
				}
			}
			for i, kePair := range w.eKESignets {
				kePair.signet = &Signet{
					Version: letter.Version,
					Key:     letter.Keys[keKeyIndexes[i]].Value,
					Public:  true,
					tool:    kePair.tool.Definition(),
				}
			}
			w.handshakeState = wireStateSendKey
		}

	case wireStateAwaitKey: // client
		if len(letter.Keys) > 0 {
			// Check if there are the right amount of keys.
			if len(letter.Keys) != len(w.eKXSignets)+len(w.eKESignets) {
				return errors.New("incorrect amount of keys in letter")
			}
			// Reject nil entries before dereferencing them.
			for _, seal := range letter.Keys {
				if seal == nil {
					return errors.New("letter contains an empty key")
				}
			}

			// Assign keys to kx/ke pairs: kx keys first, then ke seals.
			kxCount := len(w.eKXSignets)
			for i, kxPair := range w.eKXSignets {
				kxPair.peer = &Signet{
					Version: letter.Version,
					Key:     letter.Keys[i].Value,
					Public:  true,
					tool:    kxPair.tool.Definition(),
				}
			}
			for i, kePair := range w.eKESignets {
				kePair.seal = letter.Keys[kxCount+i]
			}

			// Make shared keys.
			keyMaterial, err = w.makeSharedKeys(keyMaterial)
			if err != nil {
				return fmt.Errorf("failed to create shared keys for completing handshake: %w", err)
			}
			// Unwrap keys.
			keyMaterial, err = w.unwrapKeys(keyMaterial)
			if err != nil {
				return fmt.Errorf("failed to unwrap keys for completing handshake: %w", err)
			}
			// Keep a private copy: it is applied again on the ApplyKeys message.
			w.newKeyMaterial = copyKeyMaterial(keyMaterial)
			w.handshakeState = wireStateSendApply
		}

	case wireStatsAwaitApply: // server
		if letter.ApplyKeys {
			keyMaterial = append(keyMaterial, w.newKeyMaterial...)
			burn = true
			// NOTE(review): handshakeState is not moved back to wireStateIdle
			// here (nor anywhere else in this file); unless it is reset
			// elsewhere, later handshake keys from the client are ignored.
			// Confirm this is intended.
		}
	}

	// Add the carryover key. The first message has none yet; only allocate
	// the buffer the first derived carryover key is written into.
	if w.msgNo == 0 {
		if w.session.DefaultSymmetricKeySize == 0 {
			return errors.New("missing default key size")
		}
		w.recvKeyCarryover = make([]byte, w.session.DefaultSymmetricKeySize)
	} else {
		keyMaterial = append(keyMaterial, w.recvKeyCarryover)
	}

	// Init KDF.
	err = w.session.kdf.InitKeyDerivation(letter.Nonce, keyMaterial...)
	if err != nil {
		return fmt.Errorf("failed to init %s kdf: %w", w.session.kdf.Info().Name, err)
	}

	// Derive the new carryover key.
	err = w.session.kdf.DeriveKeyWriteTo(w.recvKeyCarryover)
	if err != nil {
		return fmt.Errorf("failed to iterate session key with %s: %w", w.session.kdf.Info().Name, err)
	}
	if w.msgNo == 0 {
		// Both directions start from the same initial carryover key.
		w.sendKeyCarryover = make([]byte, len(w.recvKeyCarryover))
		copy(w.sendKeyCarryover, w.recvKeyCarryover)
	}

	// Increase msg counter.
	w.msgNo++

	// Burn only after the KDF has consumed the new key material.
	if burn {
		return w.burnEphemeralKeys()
	}
	return nil
}
// generateLocalKeyExchangeSignets creates a local signet for every key
// exchange pair that does not have one yet, and appends its public key to
// letter.Keys as a Seal without an ID.
//
// A new signet is attached to its pair only after generating, storing and
// converting it to its public form all succeeded. If any step fails, the pair
// is left without a signet, so a later call retries it instead of treating a
// half-initialized signet as ready.
func (w *WireSession) generateLocalKeyExchangeSignets(letter *Letter) (err error) {
	for _, kxp := range w.eKXSignets {
		if kxp.signet != nil {
			continue
		}

		// Generate and store the local signet.
		signet := NewSignetBase(kxp.tool.Definition())
		if err = signet.GenerateKey(); err != nil {
			return err
		}
		if err = signet.StoreKey(); err != nil {
			return err
		}

		// Convert to the public signet that is sent to the peer.
		rcpt, rcptErr := signet.AsRecipient()
		if rcptErr != nil {
			return rcptErr
		}
		if err = rcpt.StoreKey(); err != nil {
			return err
		}

		// Commit: attach the signet and add its public key to the letter.
		kxp.signet = signet
		letter.Keys = append(letter.Keys, &Seal{
			Value: rcpt.Key,
		})
	}
	return nil
}
// makeSharedKeys computes one shared key per key exchange pair from the
// pair's local signet and its peer signet, and returns keyMaterial with those
// keys appended in pair order.
//
// It fails if a pair lacks its local signet or its peer, if the peer key
// cannot be loaded, or if the tool cannot make the shared key.
func (w *WireSession) makeSharedKeys(keyMaterial [][]byte) ([][]byte, error) {
	for i := range w.eKXSignets {
		pair := w.eKXSignets[i]

		switch {
		case pair.signet == nil:
			return nil, fmt.Errorf("missing key exchange signet for %s", pair.tool.Info().Name)
		case pair.peer == nil:
			return nil, fmt.Errorf("missing key exchange recipient/peer for %s", pair.tool.Info().Name)
		}

		if err := pair.peer.LoadKey(); err != nil {
			return nil, err
		}

		shared, err := pair.tool.MakeSharedKey(pair.signet, pair.peer)
		if err != nil {
			return nil, err
		}
		keyMaterial = append(keyMaterial, shared)
	}
	return keyMaterial, nil
}
// generateLocalKeyEncapsulationSignets creates a local signet for every key
// encapsulation pair that does not have one yet, and appends its public key
// to letter.Keys as a Seal without an ID.
//
// A new signet is attached to its pair only after generating, storing and
// converting it to its public form all succeeded. If any step fails, the pair
// is left without a signet, so a later call retries it instead of treating a
// half-initialized signet as ready.
func (w *WireSession) generateLocalKeyEncapsulationSignets(letter *Letter) (err error) {
	for _, kep := range w.eKESignets {
		if kep.signet != nil {
			continue
		}

		// Generate and store the local signet.
		signet := NewSignetBase(kep.tool.Definition())
		if err = signet.GenerateKey(); err != nil {
			return err
		}
		if err = signet.StoreKey(); err != nil {
			return err
		}

		// Convert to the public signet that is sent to the peer.
		rcpt, rcptErr := signet.AsRecipient()
		if rcptErr != nil {
			return rcptErr
		}
		if err = rcpt.StoreKey(); err != nil {
			return err
		}

		// Commit: attach the signet and add its public key to the letter.
		kep.signet = signet
		letter.Keys = append(letter.Keys, &Seal{
			Value: rcpt.Key,
		})
	}
	return nil
}
// makeAndEncapsulateNewKeys creates, for every key encapsulation pair, a
// random key of w.session.DefaultSymmetricKeySize bytes and encapsulates it
// with the pair's signet. In the server flow, this signet is the public key
// assigned from the client's handshake letter. The plain key is appended to
// keyMaterial and the encapsulated key is appended to letter.Keys as a Seal
// without an ID.
//
// It fails if a pair lacks its signet, or if loading the signet key,
// generating random bytes or encapsulating fails.
func (w *WireSession) makeAndEncapsulateNewKeys(letter *Letter, keyMaterial [][]byte) ([][]byte, error) {
	for _, pair := range w.eKESignets {
		if pair.signet == nil {
			return nil, fmt.Errorf("missing key encapsulation signet for %s", pair.tool.Info().Name)
		}
		if err := pair.signet.LoadKey(); err != nil {
			return nil, err
		}

		freshKey, err := RandomBytes(w.session.DefaultSymmetricKeySize)
		if err != nil {
			return nil, err
		}
		wrapped, err := pair.tool.EncapsulateKey(freshKey, pair.signet)
		if err != nil {
			return nil, err
		}

		keyMaterial = append(keyMaterial, freshKey)
		letter.Keys = append(letter.Keys, &Seal{Value: wrapped})
	}
	return keyMaterial, nil
}
// unwrapKeys unwraps the encapsulated key held in each key encapsulation
// pair's seal using the pair's signet, and returns keyMaterial with the
// unwrapped keys appended in pair order.
//
// It fails if a pair lacks its signet or its seal, or if unwrapping fails.
func (w *WireSession) unwrapKeys(keyMaterial [][]byte) ([][]byte, error) {
	for i := range w.eKESignets {
		pair := w.eKESignets[i]

		switch {
		case pair.signet == nil:
			return nil, fmt.Errorf("missing key encapsulation signet for %s", pair.tool.Info().Name)
		case pair.seal == nil:
			return nil, fmt.Errorf("missing key encapsulation seal for %s", pair.tool.Info().Name)
		}

		plainKey, err := pair.tool.UnwrapKey(pair.seal.Value, pair.signet)
		if err != nil {
			return nil, err
		}
		keyMaterial = append(keyMaterial, plainKey)
	}
	return keyMaterial, nil
}
// burnEphemeralKeys burns all the ephemeral key material in the session. This is currently ineffective, see known issues in the project's README.
//
// The following are burned and then detached from the session, even if
// burning one of them fails:
//   - the key exchange signets and peers
//   - the key encapsulation signets and seals
//   - the new key material
//
// If any Signet.Burn call fails, the error of the last failing call is
// returned. A later successful call does not reset it to nil.
func (w *WireSession) burnEphemeralKeys() error {
	var lastErr error
	// keepErr records err unless it is nil, so a successful burn cannot mask
	// an earlier failure.
	keepErr := func(err error) {
		if err != nil {
			lastErr = err
		}
	}

	// Burn key exchange signets.
	for _, entry := range w.eKXSignets {
		if entry.signet != nil {
			keepErr(entry.signet.Burn())
		}
		entry.signet = nil
		if entry.peer != nil {
			keepErr(entry.peer.Burn())
		}
		entry.peer = nil
	}

	// Burn key encapsulation signets and seals.
	for _, entry := range w.eKESignets {
		if entry.signet != nil {
			keepErr(entry.signet.Burn())
		}
		entry.signet = nil
		if entry.seal != nil {
			Burn(entry.seal.Value)
		}
		entry.seal = nil
	}

	// Burn new key material.
	for _, part := range w.newKeyMaterial {
		Burn(part)
	}
	w.newKeyMaterial = nil

	return lastErr
}
// copyKeyMaterial returns a deep copy of keyMaterial.
//
// The result is a new outer slice of the same length whose parts are freshly
// allocated copies, so it can be kept or burned independently of the input.
// A nil or empty input yields an empty, non-nil slice. Empty or nil parts are
// copied as empty, non-nil slices.
func copyKeyMaterial(keyMaterial [][]byte) [][]byte {
	duplicate := make([][]byte, len(keyMaterial))
	for i := range keyMaterial {
		duplicate[i] = make([]byte, len(keyMaterial[i]))
		copy(duplicate[i], keyMaterial[i])
	}
	return duplicate
}