0
1
mirror of https://github.com/golang/go synced 2025-05-27 15:10:40 +00:00

crypto/internal/fips140/nistec: make SetBytes constant time

Similarly to CL 648035, SetBytes doesn't need to be constant time for
the uses we make of it in the standard library (ECDH and ECDSA public
keys), but it doesn't cost much to make it constant time for users of
the re-exported package, or even just to save the next person from
convincing themselves that it's ok for it not to be constant time.

Change-Id: I6a6a465622a0de08d9fc71db75c63185a82aa54a
Reviewed-on: https://go-review.googlesource.com/c/go/+/650579
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
This commit is contained in:
Filippo Valsorda
2025-02-19 22:41:59 +01:00
committed by Gopher Robot
parent d93f6df0cc
commit f24b299df2
8 changed files with 156 additions and 35 deletions

@ -308,6 +308,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
// Points not on the curve.
"046b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c2964fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f6",
"0400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
// Non-canonical encoding.
"04ffffffff00000001000000000000000000000001000000000000000000000004ba6dbc4555a7e7fa016ec431667e8521ee35afc49b265c3accbea3f7cdb70433",
},
ecdh.P384(): {
// Bad lengths.
@ -322,6 +324,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
// Points not on the curve.
"04aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab73617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e60",
"04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
// Non-canonical encoding.
"04fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff000000000000000100000001732152442fb6ee5c3e6ce1d920c059bc623563814d79042b903ce60f1d4487fccd450a86da03f3e6ed525d02017bfdb3",
},
ecdh.P521(): {
// Bad lengths.
@ -336,6 +340,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
// Points not on the curve.
"0400c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16651",
"04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
// Non-canonical encoding.
"0402000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100d9254fdf800496acb33790b103c5ee9fac12832fe546c632225b0f7fce3da4574b1a879b623d722fa8fc34d5fc2a8731aad691a9a8bb8b554c95a051d6aa505acf",
},
ecdh.X25519(): {},
}

@ -223,13 +223,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new({{ .Element }}).Sub(
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid {{ .Element }} encoding")
}
if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid {{ .Element }} encoding")
}
var in [{{ .Prefix }}ElementLen]byte

@ -78,13 +78,8 @@ func (e *P224Element) SetBytes(v []byte) (*P224Element, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P224Element).Sub(
new(P224Element), new(P224Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P224Element encoding")
}
if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid P224Element encoding")
}
var in [p224ElementLen]byte

@ -78,13 +78,8 @@ func (e *P256Element) SetBytes(v []byte) (*P256Element, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P256Element).Sub(
new(P256Element), new(P256Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P256Element encoding")
}
if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid P256Element encoding")
}
var in [p256ElementLen]byte

@ -78,13 +78,8 @@ func (e *P384Element) SetBytes(v []byte) (*P384Element, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P384Element).Sub(
new(P384Element), new(P384Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P384Element encoding")
}
if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid P384Element encoding")
}
var in [p384ElementLen]byte

@ -78,13 +78,8 @@ func (e *P521Element) SetBytes(v []byte) (*P521Element, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P521Element).Sub(
new(P521Element), new(P521Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P521Element encoding")
}
if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid P521Element encoding")
}
var in [p521ElementLen]byte

@ -4,6 +4,11 @@
package subtle
import (
"crypto/internal/fips140deps/byteorder"
"math/bits"
)
// ConstantTimeCompare returns 1 if the two slices, x and y, have equal contents
// and 0 otherwise. The time taken is a function of the length of the slices and
// is independent of the contents. If the lengths of x and y do not match it
@ -22,6 +27,37 @@ func ConstantTimeCompare(x, y []byte) int {
return ConstantTimeByteEq(v, 0)
}
// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison
// is lexigraphical, or big-endian. The time taken is a function of the length of
// the slices and is independent of the contents. If the lengths of x and y do not
// match it returns 0 immediately.
func ConstantTimeLessOrEqBytes(x, y []byte) int {
if len(x) != len(y) {
return 0
}
// Do a constant time subtraction chain y - x.
// If there is no borrow at the end, then x <= y.
var b uint64
for len(x) > 8 {
x0 := byteorder.BEUint64(x[len(x)-8:])
y0 := byteorder.BEUint64(y[len(y)-8:])
_, b = bits.Sub64(y0, x0, b)
x = x[:len(x)-8]
y = y[:len(y)-8]
}
if len(x) > 0 {
xb := make([]byte, 8)
yb := make([]byte, 8)
copy(xb[8-len(x):], x)
copy(yb[8-len(y):], y)
x0 := byteorder.BEUint64(xb)
y0 := byteorder.BEUint64(yb)
_, b = bits.Sub64(y0, x0, b)
}
return int(b ^ 1)
}
// ConstantTimeSelect returns x if v == 1 and y if v == 0.
// Its behavior is undefined if v takes any other value.
func ConstantTimeSelect(v, x, y int) int { return ^(v-1)&x | (v-1)&y }

@ -0,0 +1,104 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package subtle
import (
"bytes"
"crypto/internal/fips140deps/byteorder"
"math/rand/v2"
"testing"
"time"
)
func TestConstantTimeLessOrEqBytes(t *testing.T) {
seed := make([]byte, 32)
byteorder.BEPutUint64(seed, uint64(time.Now().UnixNano()))
r := rand.NewChaCha8([32]byte(seed))
for l := range 20 {
a := make([]byte, l)
b := make([]byte, l)
empty := make([]byte, l)
r.Read(a)
r.Read(b)
exp := 0
if bytes.Compare(a, b) <= 0 {
exp = 1
}
if got := ConstantTimeLessOrEqBytes(a, b); got != exp {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp)
}
exp = 0
if bytes.Compare(b, a) <= 0 {
exp = 1
}
if got := ConstantTimeLessOrEqBytes(b, a); got != exp {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp)
}
if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got)
}
if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got)
}
if got := ConstantTimeLessOrEqBytes(a, a); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got)
}
if got := ConstantTimeLessOrEqBytes(b, b); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got)
}
if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got)
}
if l == 0 {
continue
}
max := make([]byte, l)
for i := range max {
max[i] = 0xff
}
if got := ConstantTimeLessOrEqBytes(a, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got)
}
if got := ConstantTimeLessOrEqBytes(b, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got)
}
if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got)
}
if got := ConstantTimeLessOrEqBytes(max, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got)
}
aPlusOne := make([]byte, l)
copy(aPlusOne, a)
for i := l - 1; i >= 0; i-- {
if aPlusOne[i] == 0xff {
aPlusOne[i] = 0
continue
}
aPlusOne[i]++
if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got)
}
if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got)
}
break
}
shorter := make([]byte, l-1)
copy(shorter, a)
if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got)
}
if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got)
}
if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got)
}
if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got)
}
}
}