package jess

import (
	"errors"
	"fmt"
)

// WireSession fields driven by the handshake code in this file (the struct
// itself is declared elsewhere):
//
//	msgNo, lastReKeyAtMsgNo
//	sendKeyCarryover, recvKeyCarryover
//	handshakeState, eKXSignets, eKESignets, newKeyMaterial

// sendHandshakeAndInitKDF advances the sending side of the wire handshake for
// letter and initializes the session KDF for it.
//
// Depending on w.handshakeState, local ephemeral public keys and/or
// encapsulated keys are appended to letter.Keys, or letter.ApplyKeys is set to
// tell the peer that the negotiated keys are now in use. The key material
// gathered here is then fed into the KDF together with the send carryover key.
func (w *WireSession) sendHandshakeAndInitKDF(letter *Letter) error {
	var err error
	var keyMaterial [][]byte
	var keysApplied bool

	// process handshake
	switch w.handshakeState {
	case wireStateInit: // client
		keyMaterial, err = w.session.setupClosingKeyMaterial(letter)
		if err != nil {
			return fmt.Errorf("failed to setup initial sending handshake key material: %w", err)
		}
		fallthrough

	case wireStateIdle: // client and server
		// Start a handshake with the very first message, or when the client
		// decides that a re-key is due.
		if w.msgNo == 0 || (!w.server && w.reKeyNeeded()) {
			err = w.generateLocalKeyExchangeSignets(letter)
			if err != nil {
				return fmt.Errorf("failed to generate local key exchange signets for initiating handshake: %w", err)
			}
			err = w.generateLocalKeyEncapsulationSignets(letter)
			if err != nil {
				return fmt.Errorf("failed to generate local key encapsulation signets for initiating handshake: %w", err)
			}
			w.handshakeState = wireStateAwaitKey
		}

	case wireStateSendKey: // server
		err = w.generateLocalKeyExchangeSignets(letter)
		if err != nil {
			return fmt.Errorf("failed to generate local key exchange signets for completing handshake: %w", err)
		}
		keyMaterial, err = w.makeSharedKeys(keyMaterial)
		if err != nil {
			return fmt.Errorf("failed to create shared keys for completing handshake: %w", err)
		}
		err = w.generateLocalKeyEncapsulationSignets(letter)
		if err != nil {
			return fmt.Errorf("failed to generate local key encapsulation signets for completing handshake: %w", err)
		}
		keyMaterial, err = w.makeAndEncapsulateNewKeys(letter, keyMaterial)
		if err != nil {
			return fmt.Errorf("failed to encapsulate keys for completing handshake: %w", err)
		}
		// Store a copy of the new keys: they are mixed into the key material
		// again when the client confirms them with ApplyKeys.
		w.newKeyMaterial = copyKeyMaterial(keyMaterial)
		w.handshakeState = wireStatsAwaitApply

	case wireStateSendApply: // client
		keyMaterial = append(keyMaterial, w.newKeyMaterial...)
		letter.ApplyKeys = true
		keysApplied = true
	}

	return w.initKDFAndRotateCarryover(letter, keyMaterial, &w.sendKeyCarryover, &w.recvKeyCarryover, keysApplied)
}

// recvHandshakeAndInitKDF advances the receiving side of the wire handshake
// for letter and initializes the session KDF for it.
//
// letter.Keys comes from the peer; a key list that is too short (or contains
// nil entries) is rejected with an error instead of causing an index panic.
//
//nolint:gocognit
func (w *WireSession) recvHandshakeAndInitKDF(letter *Letter) error {
	var err error
	var keyMaterial [][]byte
	var keysApplied bool

	// process handshake
	switch w.handshakeState {
	case wireStateInit: // server
		keyMaterial, err = w.session.setupOpeningKeyMaterial(letter)
		if err != nil {
			return fmt.Errorf("failed to setup initial receiving handshake key material: %w", err)
		}
		fallthrough

	case wireStateIdle: // server
		if len(letter.Keys) > 0 {
			// check if there are the right amount of keys
			if len(w.session.keyEncapsulators) == 0 {
				// TODO:
				// initial wire handshake is special:
				// key encapsulators send two seals in the initial handshake messages
				// one of them is added to the recipients
				// the other is a new ephemeral key
				if len(letter.Keys) != len(w.eKXSignets)+len(w.eKESignets) {
					return errors.New("failed to setup initial receiving handshake: incorrect amount of keys in letter")
				}
			}

			// assign keys to kx/ke pairs
			var keyIndex int
			keyIndex, err = w.assignPeerKeyExchangeKeys(letter)
			if err != nil {
				return err
			}
			for _, kePair := range w.eKESignets {
				// Skip seals that carry an ID; presumably those are recipient
				// seals rather than ephemeral keys (see TODO above).
				for keyIndex < len(letter.Keys) && letter.Keys[keyIndex] != nil && letter.Keys[keyIndex].ID != "" {
					keyIndex++
				}
				if keyIndex >= len(letter.Keys) || letter.Keys[keyIndex] == nil {
					return errors.New("failed to setup receiving handshake: missing key encapsulation key in letter")
				}
				kePair.signet = &Signet{
					Version: letter.Version,
					Key:     letter.Keys[keyIndex].Value,
					Public:  true,
					tool:    kePair.tool.Definition(),
				}
				keyIndex++
			}

			w.handshakeState = wireStateSendKey
		}

	case wireStateAwaitKey: // client
		if len(letter.Keys) > 0 {
			// check if there are the right amount of keys
			if len(letter.Keys) != len(w.eKXSignets)+len(w.eKESignets) {
				return errors.New("incorrect amount of keys in letter")
			}

			// assign keys to kx/ke pairs: key exchange keys come first,
			// followed by one encapsulated key per key encapsulation pair
			var keyIndex int
			keyIndex, err = w.assignPeerKeyExchangeKeys(letter)
			if err != nil {
				return err
			}
			for _, kePair := range w.eKESignets {
				kePair.seal = letter.Keys[keyIndex]
				keyIndex++
			}

			// make shared keys
			keyMaterial, err = w.makeSharedKeys(keyMaterial)
			if err != nil {
				return err
			}

			// unwrap keys
			keyMaterial, err = w.unwrapKeys(keyMaterial)
			if err != nil {
				return err
			}

			w.newKeyMaterial = copyKeyMaterial(keyMaterial)
			w.handshakeState = wireStateSendApply
		}

	case wireStatsAwaitApply: // server
		if letter.ApplyKeys {
			keyMaterial = append(keyMaterial, w.newKeyMaterial...)
			keysApplied = true
		}
	}

	return w.initKDFAndRotateCarryover(letter, keyMaterial, &w.recvKeyCarryover, &w.sendKeyCarryover, keysApplied)
}

// initKDFAndRotateCarryover finishes processing a letter in one direction.
//
// keyMaterial is extended with that direction's carryover key (*carryover) and
// handed to the session KDF together with letter.Nonce; the KDF then derives
// the next carryover key into the same buffer. On the very first message of
// the session there is no carryover yet: a buffer of DefaultSymmetricKeySize
// bytes is allocated instead of being used as key material, and the derived
// key is copied into *opposite so both directions start from the same key.
//
// The message counter is incremented afterwards. If keysApplied is set, the
// handshake has just completed: the state machine returns to idle, the re-key
// position is recorded and the ephemeral keys are burned.
func (w *WireSession) initKDFAndRotateCarryover(
	letter *Letter,
	keyMaterial [][]byte,
	carryover, opposite *[]byte,
	keysApplied bool,
) error {
	first := w.msgNo == 0

	// add carryover key
	if first {
		if w.session.DefaultSymmetricKeySize == 0 {
			return errors.New("missing default key size")
		}
		*carryover = make([]byte, w.session.DefaultSymmetricKeySize)
	} else {
		keyMaterial = append(keyMaterial, *carryover)
	}

	// init KDF
	if err := w.session.kdf.InitKeyDerivation(letter.Nonce, keyMaterial...); err != nil {
		return fmt.Errorf("failed to init %s kdf: %w", w.session.kdf.Info().Name, err)
	}

	// derive new carryover key (in place)
	if err := w.session.kdf.DeriveKeyWriteTo(*carryover); err != nil {
		return fmt.Errorf("failed to iterate session key with %s: %w", w.session.kdf.Info().Name, err)
	}

	if first {
		// both directions start from the same initial carryover key
		*opposite = make([]byte, len(*carryover))
		copy(*opposite, *carryover)
	}

	// increase msg counter
	w.msgNo++

	if keysApplied {
		// Leave the apply state once the KDF has succeeded. Without this the
		// state machine stays in the apply state forever: every following
		// letter re-flags ApplyKeys and re-burns, and no re-key can ever be
		// started from wireStateIdle.
		w.handshakeState = wireStateIdle
		// NOTE(review): presumably read by reKeyNeeded to schedule the next re-key.
		w.lastReKeyAtMsgNo = w.msgNo
		return w.burnEphemeralKeys()
	}
	return nil
}

// assignPeerKeyExchangeKeys sets the peer signet of every key exchange pair
// from the leading entries of letter.Keys (one key per pair, in order) and
// returns the index of the first entry that was not consumed.
func (w *WireSession) assignPeerKeyExchangeKeys(letter *Letter) (int, error) {
	for i, kxPair := range w.eKXSignets {
		if i >= len(letter.Keys) || letter.Keys[i] == nil {
			return 0, errors.New("failed to setup receiving handshake: missing key exchange key in letter")
		}
		kxPair.peer = &Signet{
			Version: letter.Version,
			Key:     letter.Keys[i].Value,
			Public:  true,
			tool:    kxPair.tool.Definition(),
		}
	}
	return len(w.eKXSignets), nil
}

// generateLocalKeyExchangeSignets creates a fresh local signet for every key
// exchange pair that does not have one yet and appends its public key to
// letter.Keys.
func (w *WireSession) generateLocalKeyExchangeSignets(letter *Letter) error {
	for _, kxp := range w.eKXSignets {
		if kxp.signet != nil {
			continue
		}
		kxp.signet = NewSignetBase(kxp.tool.Definition())
		if err := generateWireEphemeralSignet(kxp.signet, letter); err != nil {
			return err
		}
	}
	return nil
}

// makeSharedKeys derives one shared key per key exchange pair from the local
// signet and the peer's public signet and appends it to keyMaterial.
// Both signets must already be set.
func (w *WireSession) makeSharedKeys(keyMaterial [][]byte) ([][]byte, error) {
	for _, kxp := range w.eKXSignets {
		// check signet
		if kxp.signet == nil {
			return nil, fmt.Errorf("missing key exchange signet for %s", kxp.tool.Info().Name)
		}
		// check peer signet
		if kxp.peer == nil {
			return nil, fmt.Errorf("missing key exchange recipient/peer for %s", kxp.tool.Info().Name)
		}

		// load peer key
		err := kxp.peer.LoadKey()
		if err != nil {
			return nil, err
		}

		// make shared key
		sharedKey, err := kxp.tool.MakeSharedKey(kxp.signet, kxp.peer)
		if err != nil {
			return nil, err
		}

		// append key to material
		keyMaterial = append(keyMaterial, sharedKey)
	}

	return keyMaterial, nil
}

// generateLocalKeyEncapsulationSignets creates a fresh local signet for every
// key encapsulation pair that does not have one yet and appends its public key
// to letter.Keys. Pairs whose signet was already set (e.g. from the peer's
// letter) are left untouched.
func (w *WireSession) generateLocalKeyEncapsulationSignets(letter *Letter) error {
	for _, kep := range w.eKESignets {
		if kep.signet != nil {
			continue
		}
		kep.signet = NewSignetBase(kep.tool.Definition())
		if err := generateWireEphemeralSignet(kep.signet, letter); err != nil {
			return err
		}
	}
	return nil
}

// generateWireEphemeralSignet generates and stores a new key for signet, then
// appends the corresponding public key to letter.Keys as a seal without ID.
// signet is expected to be a freshly created base signet.
func generateWireEphemeralSignet(signet *Signet, letter *Letter) error {
	// generate signet
	if err := signet.GenerateKey(); err != nil {
		return err
	}
	// store signet
	if err := signet.StoreKey(); err != nil {
		return err
	}

	// add to letter
	rcpt, err := signet.AsRecipient() // convert to public signet
	if err != nil {
		return err
	}
	if err := rcpt.StoreKey(); err != nil {
		return err
	}
	letter.Keys = append(letter.Keys, &Seal{
		Value: rcpt.Key,
	})
	return nil
}

// makeAndEncapsulateNewKeys generates one random key of DefaultSymmetricKeySize
// bytes per key encapsulation pair, encapsulates it with the pair's signet and
// appends the encapsulated form to letter.Keys. The plain new keys are appended
// to keyMaterial.
func (w *WireSession) makeAndEncapsulateNewKeys(letter *Letter, keyMaterial [][]byte) ([][]byte, error) {
	for _, kep := range w.eKESignets {
		// check signet
		if kep.signet == nil {
			return nil, fmt.Errorf("missing key encapsulation signet for %s", kep.tool.Info().Name)
		}

		// load signet key
		err := kep.signet.LoadKey()
		if err != nil {
			return nil, err
		}

		// generate new key
		newKey, err := RandomBytes(w.session.DefaultSymmetricKeySize)
		if err != nil {
			return nil, err
		}

		// encapsulate it
		encapsulatedKey, err := kep.tool.EncapsulateKey(newKey, kep.signet)
		if err != nil {
			return nil, err
		}

		// add key to material and letter
		keyMaterial = append(keyMaterial, newKey)
		letter.Keys = append(letter.Keys, &Seal{Value: encapsulatedKey})
	}

	return keyMaterial, nil
}

// unwrapKeys unwraps the peer's encapsulated key (the pair's seal) with the
// pair's local signet for every key encapsulation pair and appends the
// unwrapped keys to keyMaterial.
func (w *WireSession) unwrapKeys(keyMaterial [][]byte) ([][]byte, error) {
	for _, kep := range w.eKESignets {
		// check signet
		if kep.signet == nil {
			return nil, fmt.Errorf("missing key encapsulation signet for %s", kep.tool.Info().Name)
		}
		// check seal
		if kep.seal == nil {
			return nil, fmt.Errorf("missing key encapsulation seal for %s", kep.tool.Info().Name)
		}

		// unwrap key
		unwrappedKey, err := kep.tool.UnwrapKey(kep.seal.Value, kep.signet)
		if err != nil {
			return nil, err
		}

		// add key to material
		keyMaterial = append(keyMaterial, unwrappedKey)
	}

	return keyMaterial, nil
}

// burnEphemeralKeys burns all the ephemeral key material in the session and
// clears the references to it. Burning is currently ineffective, see known
// issues in the project's README.
func (w *WireSession) burnEphemeralKeys() error { var lastErr error // burn key exchange signets for _, entry := range w.eKXSignets { if entry.signet != nil { lastErr = entry.signet.Burn() } entry.signet = nil if entry.peer != nil { lastErr = entry.peer.Burn() } entry.peer = nil } // burn key encapsulation signets for _, entry := range w.eKESignets { if entry.signet != nil { lastErr = entry.signet.Burn() } entry.signet = nil if entry.seal != nil { Burn(entry.seal.Value) } entry.seal = nil } // burn new key material for _, part := range w.newKeyMaterial { Burn(part) } w.newKeyMaterial = nil return lastErr } func copyKeyMaterial(keyMaterial [][]byte) [][]byte { copied := make([][]byte, len(keyMaterial)) for index, part := range keyMaterial { copiedPart := make([]byte, len(part)) copy(copiedPart, part) copied[index] = copiedPart } return copied }