From 90ec9996cb6e7ea98ffeab1b6e28037d79e81026 Mon Sep 17 00:00:00 2001
From: Roland Shoemaker <roland@golang.org>
Date: Fri, 24 Jan 2025 14:08:03 -0800
Subject: [PATCH] crypto/pbkdf2: add keyLength limit

As specified by RFC 8018. Also prevent unexpected overflows on 32 bit
systems.

Change-Id: I50c4a177b7d1ebb15f9b3b96e515d93f19d3f68e
Reviewed-on: https://go-review.googlesource.com/c/go/+/644122
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Robert Griesemer <gri@google.com>
---
 src/crypto/internal/fips140/pbkdf2/pbkdf2.go | 20 ++++++++++++-
 src/crypto/pbkdf2/pbkdf2.go                  |  3 ++
 src/crypto/pbkdf2/pbkdf2_test.go             | 30 ++++++++++++++++++++
 3 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/src/crypto/internal/fips140/pbkdf2/pbkdf2.go b/src/crypto/internal/fips140/pbkdf2/pbkdf2.go
index 8f6d991504..05923f6826 100644
--- a/src/crypto/internal/fips140/pbkdf2/pbkdf2.go
+++ b/src/crypto/internal/fips140/pbkdf2/pbkdf2.go
@@ -7,15 +7,33 @@ package pbkdf2
 import (
 	"crypto/internal/fips140"
 	"crypto/internal/fips140/hmac"
+	"errors"
 )
 
+// divRoundUp divides x+y-1 by y, rounding up if the result is not whole.
+// This function casts x and y to int64 in order to avoid cases where
+// x+y would overflow int on systems where int is an int32. The result
+// is an int, which is safe as (x+y-1)/y should always fit, regardless
+// of the integer size.
+func divRoundUp(x, y int) int {
+	return int((int64(x) + int64(y) - 1) / int64(y))
+}
+
 func Key[Hash fips140.Hash](h func() Hash, password string, salt []byte, iter, keyLength int) ([]byte, error) {
 	setServiceIndicator(salt, keyLength)
 
+	if keyLength <= 0 {
+		return nil, errors.New("pkbdf2: keyLength must be larger than 0")
+	}
+
 	prf := hmac.New(h, []byte(password))
 	hmac.MarkAsUsedInKDF(prf)
 	hashLen := prf.Size()
-	numBlocks := (keyLength + hashLen - 1) / hashLen
+	numBlocks := divRoundUp(keyLength, hashLen)
+	const maxBlocks = int64(1<<32 - 1)
+	if keyLength+hashLen < keyLength || int64(numBlocks) > maxBlocks {
+		return nil, errors.New("pbkdf2: keyLength too long")
+	}
 
 	var buf [4]byte
 	dk := make([]byte, 0, numBlocks*hashLen)
diff --git a/src/crypto/pbkdf2/pbkdf2.go b/src/crypto/pbkdf2/pbkdf2.go
index 271d2b0331..dd5fc33f21 100644
--- a/src/crypto/pbkdf2/pbkdf2.go
+++ b/src/crypto/pbkdf2/pbkdf2.go
@@ -34,6 +34,9 @@ import (
 //
 // Using a higher iteration count will increase the cost of an exhaustive
 // search but will also make derivation proportionally slower.
+//
+// keyLength must be a positive integer between 1 and (2^32 - 1) * h.Size().
+// Setting keyLength to a value outside of this range will result in an error.
 func Key[Hash hash.Hash](h func() Hash, password string, salt []byte, iter, keyLength int) ([]byte, error) {
 	fh := fips140hash.UnwrapNew(h)
 	if fips140only.Enabled {
diff --git a/src/crypto/pbkdf2/pbkdf2_test.go b/src/crypto/pbkdf2/pbkdf2_test.go
index 03980c7e54..eb0ed14e24 100644
--- a/src/crypto/pbkdf2/pbkdf2_test.go
+++ b/src/crypto/pbkdf2/pbkdf2_test.go
@@ -221,3 +221,33 @@ func TestPBKDF2ServiceIndicator(t *testing.T) {
 		t.Error("FIPS service indicator should not be set")
 	}
 }
+
+func TestMaxKeyLength(t *testing.T) {
+	// This error cannot be triggered on platforms where int is 31 bits (i.e.
+	// 32-bit platforms), since the max value for keyLength is 1<<31-1 and
+	// 1<<31-1 * hLen will always be less than 1<<32-1 * hLen.
+	keySize := int64(1<<63 - 1)
+	if int64(int(keySize)) != keySize {
+		t.Skip("cannot be replicated on platforms where int is 31 bits")
+	}
+	_, err := pbkdf2.Key(sha256.New, "password", []byte("salt"), 1, int(keySize))
+	if err == nil {
+		t.Fatal("expected pbkdf2.Key to fail with extremely large keyLength")
+	}
+	keySize = int64(1<<32-1) * (sha256.Size + 1)
+	_, err = pbkdf2.Key(sha256.New, "password", []byte("salt"), 1, int(keySize))
+	if err == nil {
+		t.Fatal("expected pbkdf2.Key to fail with extremely large keyLength")
+	}
+}
+
+func TestZeroKeyLength(t *testing.T) {
+	_, err := pbkdf2.Key(sha256.New, "password", []byte("salt"), 1, 0)
+	if err == nil {
+		t.Fatal("expected pbkdf2.Key to fail with zero keyLength")
+	}
+	_, err = pbkdf2.Key(sha256.New, "password", []byte("salt"), 1, -1)
+	if err == nil {
+		t.Fatal("expected pbkdf2.Key to fail with negative keyLength")
+	}
+}