diff --git a/services/auth/source/oauth2/jwtsigningkey.go b/services/auth/source/oauth2/jwtsigningkey.go index 070fffe60f..92adfc4d84 100644 --- a/services/auth/source/oauth2/jwtsigningkey.go +++ b/services/auth/source/oauth2/jwtsigningkey.go @@ -347,12 +347,30 @@ func loadOrCreateAsymmetricKey() (any, error) { key, err := func() (any, error) { switch { case strings.HasPrefix(setting.OAuth2.JWTSigningAlgorithm, "RS"): - return rsa.GenerateKey(rand.Reader, 4096) + var bits int + switch setting.OAuth2.JWTSigningAlgorithm { + case "RS256": + bits = 2048 + case "RS384": + bits = 3072 + case "RS512": + bits = 4096 + } + return rsa.GenerateKey(rand.Reader, bits) case setting.OAuth2.JWTSigningAlgorithm == "EdDSA": _, pk, err := ed25519.GenerateKey(rand.Reader) return pk, err default: - return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + var curve elliptic.Curve + switch setting.OAuth2.JWTSigningAlgorithm { + case "ES256": + curve = elliptic.P256() + case "ES384": + curve = elliptic.P384() + case "ES512": + curve = elliptic.P521() + } + return ecdsa.GenerateKey(curve, rand.Reader) } }() if err != nil { diff --git a/services/auth/source/oauth2/jwtsigningkey_test.go b/services/auth/source/oauth2/jwtsigningkey_test.go new file mode 100644 index 0000000000..4db538b0e8 --- /dev/null +++ b/services/auth/source/oauth2/jwtsigningkey_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package oauth2 + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "testing" + + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/test" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadOrCreateAsymmetricKey(t *testing.T) { + loadKey := func(t *testing.T) any { + t.Helper() + loadOrCreateAsymmetricKey() + + fileContent, err := os.ReadFile(setting.OAuth2.JWTSigningPrivateKeyFile) + require.NoError(t, err) + + block, _ := pem.Decode(fileContent) + assert.NotNil(t, block) + assert.EqualValues(t, "PRIVATE KEY", block.Type) + + parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + require.NoError(t, err) + + return parsedKey + } + t.Run("RSA-2048", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-rsa-2048.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "RS256")() + + parsedKey := loadKey(t) + + rsaPrivateKey := parsedKey.(*rsa.PrivateKey) + assert.EqualValues(t, 2048, rsaPrivateKey.N.BitLen()) + + t.Run("Load key with differ specified algorithm", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "EdDSA")() + + parsedKey := loadKey(t) + rsaPrivateKey := parsedKey.(*rsa.PrivateKey) + assert.EqualValues(t, 2048, rsaPrivateKey.N.BitLen()) + }) + }) + + t.Run("RSA-3072", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-rsa-3072.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "RS384")() + + parsedKey := loadKey(t) + + rsaPrivateKey := parsedKey.(*rsa.PrivateKey) + assert.EqualValues(t, 3072, rsaPrivateKey.N.BitLen()) + }) + + t.Run("RSA-4096", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-rsa-4096.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "RS512")() + + parsedKey := loadKey(t) + + rsaPrivateKey := parsedKey.(*rsa.PrivateKey) + assert.EqualValues(t, 4096, rsaPrivateKey.N.BitLen()) + }) + + t.Run("ECDSA-256", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-ecdsa-256.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "ES256")() + + parsedKey := loadKey(t) + + ecdsaPrivateKey := parsedKey.(*ecdsa.PrivateKey) + assert.EqualValues(t, 256, ecdsaPrivateKey.Params().BitSize) + }) + + t.Run("ECDSA-384", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-ecdsa-384.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "ES384")() + + parsedKey := loadKey(t) + + ecdsaPrivateKey := parsedKey.(*ecdsa.PrivateKey) + assert.EqualValues(t, 384, ecdsaPrivateKey.Params().BitSize) + }) + + t.Run("ECDSA-512", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-ecdsa-512.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "ES512")() + + parsedKey := loadKey(t) + + ecdsaPrivateKey := parsedKey.(*ecdsa.PrivateKey) + assert.EqualValues(t, 521, ecdsaPrivateKey.Params().BitSize) + }) + + t.Run("EdDSA", func(t *testing.T) { + defer test.MockVariableValue(&setting.OAuth2.JWTSigningPrivateKeyFile, filepath.Join(t.TempDir(), "jwt-eddsa.priv"))() + defer test.MockVariableValue(&setting.OAuth2.JWTSigningAlgorithm, "EdDSA")() + + parsedKey := loadKey(t) + + assert.NotNil(t, parsedKey.(ed25519.PrivateKey)) + }) +}