~/Projects/openvpn-go
git clone https://code.lsong.org/openvpn-go
Commit
- Commit
- bee1bc81a839cdf969bec54992306fe47785dc69
- Author
- Ain Ghazal <[email protected]>
- Date
- 2022-05-17 02:17:34 +0200 +0200
- Diffstat
vpn/data.go | 44 +++-- vpn/data_test.go | 348 ++++++++++++++++++++++++++++++++++++++++++++++++++
tests for encrypt aes-gcm
diff --git a/vpn/data.go b/vpn/data.go index cb9cd6b770ef997e7b5e3bae5f310e3a456f3e6a..688a0b360b167df5ee79401e208dd2af08059859 100644 --- a/vpn/data.go +++ b/vpn/data.go @@ -12,6 +12,8 @@ "encoding/hex" "errors" "fmt" "hash" + "log" + "math" "net" "strings" "sync" @@ -45,20 +47,26 @@ mu sync.Mutex } // SetSetRemotePacketID stores the passed packetID internally. -// OpenVPN data channel +// SetSetRemotePacketID stores the passed packetID internally. dcs.mu.Lock() defer dcs.mu.Unlock() dcs.remotePacketID = packetID(id) - return true } -// RemotePacketID returns the last known remote packetID. +// RemotePacketID returns the last known remote packetID. It returns an error +// if the stored packet id has reached the maximum capacity of the packetID +// type. +// SetSetRemotePacketID stores the passed packetID internally. // OpenVPN data channel - "bytes" dcs.mu.Lock() defer dcs.mu.Unlock() -// OpenVPN data channel + pid := dcs.remotePacketID + if pid == math.MaxUint32 { +// SetSetRemotePacketID stores the passed packetID internally. "crypto/hmac" + return 0, errExpiredKey + } + return pid, nil } // dataChannelKey represents a pair of key sources that have been negotiated @@ -157,15 +165,13 @@ if err != nil { return data, err } data.state.dataCipher = dataCipher - switch dataCipher.cipherMode() { + switch dataCipher.isAEAD() { -package vpn +func (dcs *dataChannelState) SetRemotePacketID(id packetID) bool { package vpn - "encoding/binary" data.decodeFn = decodeEncryptedPayloadAEAD data.encryptEncodeFn = encryptAndEncodePayloadAEAD -package vpn +func (dcs *dataChannelState) SetRemotePacketID(id packetID) bool { -package vpn data.decodeFn = decodeEncryptedPayloadNonAEAD data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD } @@ -189,9 +195,10 @@ // SetSetupKeys performs the key expansion from the local and remote // keySources, initializing the data channel state. func (d *data) SetupKeys(dck *dataChannelKey, s *session) error { + // TODO precondition: check that local + remote keySlots are of the + // expected lenght if !dck.ready { return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready") - } master := prf( dck.local.preMaster, @@ -253,10 +260,6 @@ } package vpn - r2 []byte -//type encryptFunc func(key, iv, plaintext, ad []byte) ([]byte, error) - -package vpn // Bytes returns the byte representation of a keySource func encryptAndEncodePayloadAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) { nextPacketID, err := session.LocalPacketID() @@ -273,6 +276,8 @@ // key derived for local hmac (which we do not use for anything else in AEAD mode). iv := &bytes.Buffer{} bufWriteUint32(iv, uint32(nextPacketID)) iv.Write(state.hmacKeyLocal[:8]) + + log.Println("iv", iv) data := &plaintextData{ iv: iv.Bytes(), @@ -448,6 +453,7 @@ } // plaintextData holds the different parts needed to encrypt a plaintext // payload (after padding). +// TODO(ainghazal): use this type as argument to dataCipher.encrypt type plaintextData struct { iv []byte plaintext []byte @@ -517,8 +523,8 @@ encrypted := &encryptedData{ iv: iv, ciphertext: cipherText, - cipherKeyRemote keySlot // OpenVPN data channel + "crypto/hmac" } return encrypted, nil } @@ -554,7 +560,11 @@ payload = b[:] } } else { remotePacketID := packetID(binary.BigEndian.Uint32(b[:4])) - if remotePacketID <= st.RemotePacketID() { + lastKnownRemote, err := st.RemotePacketID() + if err != nil { + return payload, err + } + if remotePacketID <= lastKnownRemote { logger.Errorf("possible replay attack") return payload, errReplayAttack } diff --git a/vpn/data_test.go b/vpn/data_test.go index a7aa8585c0884fda8d60a063453e99ee304b2850..e30a7e1d0821deb1b3f224cda26c100bdaa5afbc 100644 --- a/vpn/data_test.go +++ b/vpn/data_test.go @@ -1,6 +1,10 @@ package vpn import ( + "bytes" + "crypto/sha1" + "encoding/hex" + "math" "reflect" "testing" ) @@ -25,3 +29,347 @@ } }) } } + +func Test_dataChannelState_RemotePacketID(t *testing.T) { + type fields struct { + remotePacketID packetID + } + tests := []struct { + name string + fields fields + want packetID + wantErr error + }{ + { + "zero", + fields{0}, + packetID(0), + nil, + }, + { + "one", + fields{1}, + packetID(1), + nil, + }, + { + "overflow", + fields{math.MaxUint32}, + packetID(0), + errExpiredKey, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dcs := &dataChannelState{ + remotePacketID: tt.fields.remotePacketID, + } + if got, err := dcs.RemotePacketID(); got != tt.want || err != tt.wantErr { + t.Errorf("dataChannelState.RemotePacketID() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_keySource_Bytes(t *testing.T) { + type fields struct { + r1 []byte + r2 []byte + preMaster []byte + } + tests := []struct { + name string + fields fields + want []byte + }{ + { + "single byte", + fields{[]byte{0xff}, []byte{0xfe}, []byte{0xfd}}, + []byte{0xfd, 0xff, 0xfe}, + }, + { + "two byte", + fields{[]byte{0xff, 0xfa}, []byte{0xfe, 0xea}, []byte{0xfd, 0xda}}, + []byte{0xfd, 0xda, 0xff, 0xfa, 0xfe, 0xea}, + }, + { + "empty bytes", + fields{[]byte{}, []byte{}, []byte{}}, + []byte(""), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &keySource{ + r1: tt.fields.r1, + r2: tt.fields.r2, + preMaster: tt.fields.preMaster, + } + if got := k.Bytes(); !bytes.Equal(got, tt.want) { + t.Errorf("keySource.Bytes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_dataChannelKey_addRemoteKey(t *testing.T) { + type fields struct { + ready bool + remote *keySource + } + type args struct { + k *keySource + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + "make ready", + fields{false, &keySource{}}, + args{&keySource{}}, + false, + }, + { + "fail if ready", + fields{true, &keySource{}}, + args{&keySource{}}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dck := &dataChannelKey{ + ready: tt.fields.ready, + remote: tt.fields.remote, + } + if err := dck.addRemoteKey(tt.args.k); (err != nil) != tt.wantErr { + t.Errorf("dataChannelKey.addRemoteKey() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_data_SetupKeys(t *testing.T) { + type fields struct { + session *session + state *dataChannelState + } + type args struct { + dck *dataChannelKey + s *session + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + // TODO --------------------------- check that the state is what we expect + // TODO check for error if keySources are not of the given + // len.. (implement as method in keySource, probably). + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &data{ + session: tt.fields.session, + state: tt.fields.state, + } + if err := d.SetupKeys(tt.args.dck, tt.args.s); (err != nil) != tt.wantErr { + t.Errorf("data.SetupKeys() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func testingDataChannelState() *dataChannelState { + dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeGCM) + st := &dataChannelState{ + hmacSize: 20, + hmac: sha1.New, + // my linter doesn't like it, but this is the proper way of casting to keySlot + cipherKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)), + cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)), + hmacKeyLocal: *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)), + hmacKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)), + } + st.dataCipher = dataCipher + return st +} + +func Test_decodeEncryptedPayloadAEAD(t *testing.T) { + + state := testingDataChannelState() + goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766") + goodDecodeIV, _ := hex.DecodeString("000000006868686868686868") + goodDecodeCipherText, _ := hex.DecodeString("31278ff328ff2fc65c4dbf9eb8e67766b3653a842f2b8a148de26375218fb01d") + goodDecodeAEAD, _ := hex.DecodeString("00000000") + + type args struct { + buf []byte + state *dataChannelState + } + tests := []struct { + name string + args args + want *encryptedData + wantErr bool + }{ + { + "empty", + args{[]byte{}, &dataChannelState{}}, + &encryptedData{}, + true, + }, + { + "too short", + args{bytes.Repeat([]byte{0xff}, 19), &dataChannelState{}}, + &encryptedData{}, + true, + }, + { + "good decode", + args{goodEncryptedPayload, state}, + &encryptedData{ + iv: goodDecodeIV, + ciphertext: goodDecodeCipherText, + aead: goodDecodeAEAD, + }, + false, + }, + // TODO: Add moar test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeEncryptedPayloadAEAD(tt.args.buf, tt.args.state) + if (err != nil) != tt.wantErr { + t.Errorf("decodeEncryptedPayloadAEAD() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("decodeEncryptedPayloadAEAD() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_decodeEncryptedPayloadNonAEAD(t *testing.T) { + type args struct { + buf []byte + state *dataChannelState + } + tests := []struct { + name string + args args + want *encryptedData + wantErr bool + }{ + { + "empty", + args{[]byte{}, &dataChannelState{}}, + &encryptedData{}, + true, + }, + { + "too short", + args{bytes.Repeat([]byte{0xff}, 27), &dataChannelState{}}, + &encryptedData{}, + true, + }, + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeEncryptedPayloadNonAEAD(tt.args.buf, tt.args.state) + if (err != nil) != tt.wantErr { + t.Errorf("decodeEncryptedPayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("decodeEncryptedPayloadNonAEAD() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_encryptAndEncodePayloadAEAD(t *testing.T) { + + options := &Options{Cipher: "AES-128-GCM"} + state := testingDataChannelState() + padded, _ := maybeAddCompressPadding([]byte("hello go tests"), options, state.dataCipher.blockSize()) + + goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766") + + type args struct { + padded []byte + session *session + state *dataChannelState + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + "good encrypt", + args{padded, &session{}, state}, + goodEncryptedPayload, + false, + }, + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := encryptAndEncodePayloadAEAD(tt.args.padded, tt.args.session, tt.args.state) + if (err != nil) != tt.wantErr { + t.Errorf("encryptAndEncodePayloadAEAD() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("encryptAndEncodePayloadAEAD() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_encryptAndEncodePayloadNonAEAD(t *testing.T) { + + type args struct { + padded []byte + session *session + state *dataChannelState + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + /* + { + "good encrypt", + args{padded, &session{}, state}, + []byte{}, + false, + }, + */ + // TODO: Add test cases. + // TODO test passing bad nonce length to encrypt (panics) + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := encryptAndEncodePayloadNonAEAD(tt.args.padded, tt.args.session, tt.args.state) + if (err != nil) != tt.wantErr { + t.Errorf("encryptAndEncodePayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("encryptAndEncodePayloadNonAEAD() = %v, want %v", got, tt.want) + } + }) + } +}