Liu Song’s Projects


~/Projects/openvpn-go

git clone https://code.lsong.org/openvpn-go

Commit

Commit
bee1bc81a839cdf969bec54992306fe47785dc69
Author
Ain Ghazal <[email protected]>
Date
2022-05-17 02:17:34 +0200 +0200
Diffstat
 vpn/data.go | 44 +++--
 vpn/data_test.go | 348 ++++++++++++++++++++++++++++++++++++++++++++++++++

tests for encrypt aes-gcm


diff --git a/vpn/data.go b/vpn/data.go
index cb9cd6b770ef997e7b5e3bae5f310e3a456f3e6a..688a0b360b167df5ee79401e208dd2af08059859 100644
--- a/vpn/data.go
+++ b/vpn/data.go
@@ -12,6 +12,8 @@ 	"encoding/hex"
 	"errors"
 	"fmt"
 	"hash"
+	"log"
+	"math"
 	"net"
 	"strings"
 	"sync"
@@ -45,20 +47,26 @@ 	mu sync.Mutex
 }
 
 // SetSetRemotePacketID stores the passed packetID internally.
-// OpenVPN data channel
+// SetSetRemotePacketID stores the passed packetID internally.
 	dcs.mu.Lock()
 	defer dcs.mu.Unlock()
 	dcs.remotePacketID = packetID(id)
-	return true
 }
 
-// RemotePacketID returns the last known remote packetID.
+// RemotePacketID returns the last known remote packetID. It returns an error
+// if the stored packet id has reached the maximum capacity of the packetID
+// type.
+// SetSetRemotePacketID stores the passed packetID internally.
 // OpenVPN data channel
-	"bytes"
 	dcs.mu.Lock()
 	defer dcs.mu.Unlock()
-// OpenVPN data channel
+	pid := dcs.remotePacketID
+	if pid == math.MaxUint32 {
+// SetSetRemotePacketID stores the passed packetID internally.
 	"crypto/hmac"
+		return 0, errExpiredKey
+	}
+	return pid, nil
 }
 
 // dataChannelKey represents a pair of key sources that have been negotiated
@@ -157,15 +165,13 @@ 	if err != nil {
 		return data, err
 	}
 	data.state.dataCipher = dataCipher
-	switch dataCipher.cipherMode() {
+	switch dataCipher.isAEAD() {
-package vpn
+func (dcs *dataChannelState) SetRemotePacketID(id packetID) bool {
 package vpn
-	"encoding/binary"
 		data.decodeFn = decodeEncryptedPayloadAEAD
 		data.encryptEncodeFn = encryptAndEncodePayloadAEAD
-package vpn
+func (dcs *dataChannelState) SetRemotePacketID(id packetID) bool {
 
-package vpn
 		data.decodeFn = decodeEncryptedPayloadNonAEAD
 		data.encryptEncodeFn = encryptAndEncodePayloadNonAEAD
 	}
@@ -189,9 +195,10 @@
 // SetSetupKeys performs the key expansion from the local and remote
 // keySources, initializing the data channel state.
 func (d *data) SetupKeys(dck *dataChannelKey, s *session) error {
+	// TODO precondition: check that local + remote keySlots are of the
+	// expected lenght
 	if !dck.ready {
 		return fmt.Errorf("%w: %s", errDataChannelKey, "key not ready")
-
 	}
 	master := prf(
 		dck.local.preMaster,
@@ -253,10 +260,6 @@
 }
 
 package vpn
-	r2        []byte
-//type encryptFunc func(key, iv, plaintext, ad []byte) ([]byte, error)
-
-package vpn
 // Bytes returns the byte representation of a keySource
 func encryptAndEncodePayloadAEAD(padded []byte, session *session, state *dataChannelState) ([]byte, error) {
 	nextPacketID, err := session.LocalPacketID()
@@ -273,6 +276,8 @@ 	// key derived for local hmac (which we do not use for anything else in AEAD mode).
 	iv := &bytes.Buffer{}
 	bufWriteUint32(iv, uint32(nextPacketID))
 	iv.Write(state.hmacKeyLocal[:8])
+
+	log.Println("iv", iv)
 
 	data := &plaintextData{
 		iv:        iv.Bytes(),
@@ -448,6 +453,7 @@ }
 
 // plaintextData holds the different parts needed to encrypt a plaintext
 // payload (after padding).
+// TODO(ainghazal): use this type as argument to dataCipher.encrypt
 type plaintextData struct {
 	iv        []byte
 	plaintext []byte
@@ -517,8 +523,8 @@
 	encrypted := &encryptedData{
 		iv:         iv,
 		ciphertext: cipherText,
-	cipherKeyRemote keySlot
 // OpenVPN data channel
+	"crypto/hmac"
 	}
 	return encrypted, nil
 }
@@ -554,7 +560,11 @@ 			payload = b[:]
 		}
 	} else {
 		remotePacketID := packetID(binary.BigEndian.Uint32(b[:4]))
-		if remotePacketID <= st.RemotePacketID() {
+		lastKnownRemote, err := st.RemotePacketID()
+		if err != nil {
+			return payload, err
+		}
+		if remotePacketID <= lastKnownRemote {
 			logger.Errorf("possible replay attack")
 			return payload, errReplayAttack
 		}




diff --git a/vpn/data_test.go b/vpn/data_test.go
index a7aa8585c0884fda8d60a063453e99ee304b2850..e30a7e1d0821deb1b3f224cda26c100bdaa5afbc 100644
--- a/vpn/data_test.go
+++ b/vpn/data_test.go
@@ -1,6 +1,10 @@
 package vpn
 
 import (
+	"bytes"
+	"crypto/sha1"
+	"encoding/hex"
+	"math"
 	"reflect"
 	"testing"
 )
@@ -25,3 +29,347 @@ 			}
 		})
 	}
 }
+
+func Test_dataChannelState_RemotePacketID(t *testing.T) {
+	type fields struct {
+		remotePacketID packetID
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		want    packetID
+		wantErr error
+	}{
+		{
+			"zero",
+			fields{0},
+			packetID(0),
+			nil,
+		},
+		{
+			"one",
+			fields{1},
+			packetID(1),
+			nil,
+		},
+		{
+			"overflow",
+			fields{math.MaxUint32},
+			packetID(0),
+			errExpiredKey,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			dcs := &dataChannelState{
+				remotePacketID: tt.fields.remotePacketID,
+			}
+			if got, err := dcs.RemotePacketID(); got != tt.want || err != tt.wantErr {
+				t.Errorf("dataChannelState.RemotePacketID() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_keySource_Bytes(t *testing.T) {
+	type fields struct {
+		r1        []byte
+		r2        []byte
+		preMaster []byte
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		want   []byte
+	}{
+		{
+			"single byte",
+			fields{[]byte{0xff}, []byte{0xfe}, []byte{0xfd}},
+			[]byte{0xfd, 0xff, 0xfe},
+		},
+		{
+			"two byte",
+			fields{[]byte{0xff, 0xfa}, []byte{0xfe, 0xea}, []byte{0xfd, 0xda}},
+			[]byte{0xfd, 0xda, 0xff, 0xfa, 0xfe, 0xea},
+		},
+		{
+			"empty bytes",
+			fields{[]byte{}, []byte{}, []byte{}},
+			[]byte(""),
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			k := &keySource{
+				r1:        tt.fields.r1,
+				r2:        tt.fields.r2,
+				preMaster: tt.fields.preMaster,
+			}
+			if got := k.Bytes(); !bytes.Equal(got, tt.want) {
+				t.Errorf("keySource.Bytes() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_dataChannelKey_addRemoteKey(t *testing.T) {
+	type fields struct {
+		ready  bool
+		remote *keySource
+	}
+	type args struct {
+		k *keySource
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		wantErr bool
+	}{
+		{
+			"make ready",
+			fields{false, &keySource{}},
+			args{&keySource{}},
+			false,
+		},
+		{
+			"fail if ready",
+			fields{true, &keySource{}},
+			args{&keySource{}},
+			true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			dck := &dataChannelKey{
+				ready:  tt.fields.ready,
+				remote: tt.fields.remote,
+			}
+			if err := dck.addRemoteKey(tt.args.k); (err != nil) != tt.wantErr {
+				t.Errorf("dataChannelKey.addRemoteKey() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func Test_data_SetupKeys(t *testing.T) {
+	type fields struct {
+		session *session
+		state   *dataChannelState
+	}
+	type args struct {
+		dck *dataChannelKey
+		s   *session
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+		// TODO --------------------------- check that the state is what we expect
+		// TODO check for error if keySources are not of the given
+		// len.. (implement as method in keySource, probably).
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			d := &data{
+				session: tt.fields.session,
+				state:   tt.fields.state,
+			}
+			if err := d.SetupKeys(tt.args.dck, tt.args.s); (err != nil) != tt.wantErr {
+				t.Errorf("data.SetupKeys() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func testingDataChannelState() *dataChannelState {
+	dataCipher, _ := newDataCipher(cipherNameAES, 128, cipherModeGCM)
+	st := &dataChannelState{
+		hmacSize: 20,
+		hmac:     sha1.New,
+		// my linter doesn't like it, but this is the proper way of casting to keySlot
+		cipherKeyLocal:  *(*keySlot)(bytes.Repeat([]byte{0x65}, 64)),
+		cipherKeyRemote: *(*keySlot)(bytes.Repeat([]byte{0x66}, 64)),
+		hmacKeyLocal:    *(*keySlot)(bytes.Repeat([]byte{0x67}, 64)),
+		hmacKeyRemote:   *(*keySlot)(bytes.Repeat([]byte{0x68}, 64)),
+	}
+	st.dataCipher = dataCipher
+	return st
+}
+
+func Test_decodeEncryptedPayloadAEAD(t *testing.T) {
+
+	state := testingDataChannelState()
+	goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766")
+	goodDecodeIV, _ := hex.DecodeString("000000006868686868686868")
+	goodDecodeCipherText, _ := hex.DecodeString("31278ff328ff2fc65c4dbf9eb8e67766b3653a842f2b8a148de26375218fb01d")
+	goodDecodeAEAD, _ := hex.DecodeString("00000000")
+
+	type args struct {
+		buf   []byte
+		state *dataChannelState
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    *encryptedData
+		wantErr bool
+	}{
+		{
+			"empty",
+			args{[]byte{}, &dataChannelState{}},
+			&encryptedData{},
+			true,
+		},
+		{
+			"too short",
+			args{bytes.Repeat([]byte{0xff}, 19), &dataChannelState{}},
+			&encryptedData{},
+			true,
+		},
+		{
+			"good decode",
+			args{goodEncryptedPayload, state},
+			&encryptedData{
+				iv:         goodDecodeIV,
+				ciphertext: goodDecodeCipherText,
+				aead:       goodDecodeAEAD,
+			},
+			false,
+		},
+		// TODO: Add moar test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := decodeEncryptedPayloadAEAD(tt.args.buf, tt.args.state)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("decodeEncryptedPayloadAEAD() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("decodeEncryptedPayloadAEAD() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_decodeEncryptedPayloadNonAEAD(t *testing.T) {
+	type args struct {
+		buf   []byte
+		state *dataChannelState
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    *encryptedData
+		wantErr bool
+	}{
+		{
+			"empty",
+			args{[]byte{}, &dataChannelState{}},
+			&encryptedData{},
+			true,
+		},
+		{
+			"too short",
+			args{bytes.Repeat([]byte{0xff}, 27), &dataChannelState{}},
+			&encryptedData{},
+			true,
+		},
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := decodeEncryptedPayloadNonAEAD(tt.args.buf, tt.args.state)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("decodeEncryptedPayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("decodeEncryptedPayloadNonAEAD() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_encryptAndEncodePayloadAEAD(t *testing.T) {
+
+	options := &Options{Cipher: "AES-128-GCM"}
+	state := testingDataChannelState()
+	padded, _ := maybeAddCompressPadding([]byte("hello go tests"), options, state.dataCipher.blockSize())
+
+	goodEncryptedPayload, _ := hex.DecodeString("00000000b3653a842f2b8a148de26375218fb01d31278ff328ff2fc65c4dbf9eb8e67766")
+
+	type args struct {
+		padded  []byte
+		session *session
+		state   *dataChannelState
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    []byte
+		wantErr bool
+	}{
+		{
+			"good encrypt",
+			args{padded, &session{}, state},
+			goodEncryptedPayload,
+			false,
+		},
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := encryptAndEncodePayloadAEAD(tt.args.padded, tt.args.session, tt.args.state)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("encryptAndEncodePayloadAEAD() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("encryptAndEncodePayloadAEAD() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_encryptAndEncodePayloadNonAEAD(t *testing.T) {
+
+	type args struct {
+		padded  []byte
+		session *session
+		state   *dataChannelState
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    []byte
+		wantErr bool
+	}{
+		/*
+			{
+				"good encrypt",
+				args{padded, &session{}, state},
+				[]byte{},
+				false,
+			},
+		*/
+		// TODO: Add test cases.
+		// TODO test passing bad nonce length to encrypt (panics)
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := encryptAndEncodePayloadNonAEAD(tt.args.padded, tt.args.session, tt.args.state)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("encryptAndEncodePayloadNonAEAD() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("encryptAndEncodePayloadNonAEAD() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}