Pulse/internal/notifications/email_enhanced_test.go
2026-03-29 14:50:36 +01:00

839 lines
22 KiB
Go

package notifications
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net"
"net/mail"
"net/textproto"
"strings"
"testing"
"time"
)
// TestNewEnhancedEmailManager verifies that the constructor returns a manager
// whose rate limiter exists and carries the configured RateLimit value.
func TestNewEnhancedEmailManager(t *testing.T) {
	cfg := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "smtp.example.com",
			SMTPPort: 587,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
		RateLimit: 10,
	}

	mgr := NewEnhancedEmailManager(cfg)
	switch {
	case mgr == nil:
		t.Fatal("NewEnhancedEmailManager returned nil")
	case mgr.rateLimit == nil:
		t.Fatal("rate limiter not initialized")
	}

	if got := mgr.rateLimit.rate; got != 10 {
		t.Errorf("expected rate limit 10, got %d", got)
	}
}
// TestCheckRateLimit_NoLimit verifies that a zero RateLimit disables limiting,
// so any number of consecutive checks succeed.
func TestCheckRateLimit_NoLimit(t *testing.T) {
	mgr := NewEnhancedEmailManager(EmailProviderConfig{RateLimit: 0}) // No limit

	for attempt := 0; attempt < 100; attempt++ {
		err := mgr.checkRateLimit()
		if err == nil {
			continue
		}
		t.Errorf("checkRateLimit should not error with no limit: %v", err)
	}
}
// TestCheckRateLimit_ExceedsLimit verifies that after RateLimit successful
// checks, the next check is rejected with a "rate limit exceeded" error.
func TestCheckRateLimit_ExceedsLimit(t *testing.T) {
	manager := NewEnhancedEmailManager(EmailProviderConfig{
		RateLimit: 3,
	})

	// First 3 should succeed.
	for i := 0; i < 3; i++ {
		if err := manager.checkRateLimit(); err != nil {
			t.Errorf("call %d should succeed: %v", i+1, err)
		}
	}

	// 4th should fail. Fatal (not Error) so a nil error stops the test here
	// instead of panicking on err.Error() below.
	err := manager.checkRateLimit()
	if err == nil {
		t.Fatal("expected rate limit error on 4th call")
	}
	if !strings.Contains(err.Error(), "rate limit exceeded") {
		t.Errorf("unexpected error message: %v", err)
	}
}
// TestCheckRateLimit_ResetsAfterMinute exhausts the limit, confirms further
// checks are rejected, then backdates the limiter's lastSent by two minutes and
// expects checks to be allowed again.
func TestCheckRateLimit_ResetsAfterMinute(t *testing.T) {
	manager := NewEnhancedEmailManager(EmailProviderConfig{
		RateLimit: 2,
	})

	// Use up the limit; both calls must succeed for the test to be meaningful.
	for i := 0; i < 2; i++ {
		if err := manager.checkRateLimit(); err != nil {
			t.Fatalf("call %d should succeed: %v", i+1, err)
		}
	}

	// Confirm the limiter is actually engaged before simulating the reset;
	// otherwise the final assertion would pass even if limiting never worked.
	if err := manager.checkRateLimit(); err == nil {
		t.Fatal("expected rate limit error once the limit is used up")
	}

	// Manually set lastSent to over a minute ago.
	manager.rateLimit.mu.Lock()
	manager.rateLimit.lastSent = time.Now().Add(-2 * time.Minute)
	manager.rateLimit.mu.Unlock()

	// Should succeed after reset.
	if err := manager.checkRateLimit(); err != nil {
		t.Errorf("should succeed after minute reset: %v", err)
	}
}
func TestSendViaProvider_ProviderUsernameDefaults(t *testing.T) {
tests := []struct {
name string
provider string
initialUsername string
expectedUsername string
}{
{
name: "SendGrid sets apikey username",
provider: "SendGrid",
initialUsername: "",
expectedUsername: "apikey",
},
{
name: "SendGrid preserves existing username",
provider: "SendGrid",
initialUsername: "custom",
expectedUsername: "custom",
},
{
name: "SparkPost sets SMTP_Injection username",
provider: "SparkPost",
initialUsername: "",
expectedUsername: "SMTP_Injection",
},
{
name: "Resend sets resend username",
provider: "Resend",
initialUsername: "",
expectedUsername: "resend",
},
{
name: "Unknown provider leaves username unchanged",
provider: "Custom",
initialUsername: "",
expectedUsername: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := EmailProviderConfig{
EmailConfig: EmailConfig{
SMTPHost: "invalid.host.test", // Will fail to connect
SMTPPort: 587,
Username: tt.initialUsername,
Password: "test",
From: "test@example.com",
To: []string{"recipient@example.com"},
},
Provider: tt.provider,
AuthRequired: true,
}
manager := NewEnhancedEmailManager(config)
// Call sendViaProvider - it will fail on connection, but will have set username
_ = manager.sendViaProvider([]byte("test"))
if manager.config.Username != tt.expectedUsername {
t.Errorf("expected username %q, got %q", tt.expectedUsername, manager.config.Username)
}
})
}
}
func TestSendViaProvider_PostmarkUsernameFromPassword(t *testing.T) {
config := EmailProviderConfig{
EmailConfig: EmailConfig{
SMTPHost: "invalid.host.test",
SMTPPort: 587,
Username: "",
Password: "postmark-api-token",
From: "test@example.com",
To: []string{"recipient@example.com"},
},
Provider: "Postmark",
AuthRequired: true,
}
manager := NewEnhancedEmailManager(config)
_ = manager.sendViaProvider([]byte("test"))
// Postmark copies password to username when username is empty
if manager.config.Username != "postmark-api-token" {
t.Errorf("expected username to be set from password, got %q", manager.config.Username)
}
}
func TestSendViaProvider_RoutingByTLSConfig(t *testing.T) {
tests := []struct {
name string
tls bool
startTLS bool
port int
expectsError string // Partial match of expected error
}{
{
name: "TLS true routes to sendTLS",
tls: true,
startTLS: false,
port: 587,
expectsError: "TLS dial failed",
},
{
name: "Port 465 routes to sendTLS",
tls: false,
startTLS: false,
port: 465,
expectsError: "TLS dial failed",
},
{
name: "StartTLS routes to sendStartTLS",
tls: false,
startTLS: true,
port: 587,
expectsError: "TCP dial failed",
},
{
name: "Plain routes to sendPlain",
tls: false,
startTLS: false,
port: 25,
expectsError: "TCP dial failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := EmailProviderConfig{
EmailConfig: EmailConfig{
SMTPHost: "invalid.host.test",
SMTPPort: tt.port,
TLS: tt.tls,
StartTLS: tt.startTLS,
From: "test@example.com",
To: []string{"recipient@example.com"},
},
}
manager := NewEnhancedEmailManager(config)
err := manager.sendViaProvider([]byte("test"))
if err == nil {
t.Error("expected connection error")
return
}
if !strings.Contains(err.Error(), tt.expectsError) {
t.Errorf("expected error containing %q, got %q", tt.expectsError, err.Error())
}
})
}
}
// TestSendEmailWithRetry_RetriesOnFailure expects SendEmailWithRetry to give up
// against an unreachable host and report the number of attempts made. With
// MaxRetries 2 the error must mention "3 attempts".
func TestSendEmailWithRetry_RetriesOnFailure(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
		MaxRetries: 2,
		RetryDelay: 0, // No delay for tests
	}
	manager := NewEnhancedEmailManager(config)

	err := manager.SendEmailWithRetry("Test", "<p>test</p>", "test")
	if err == nil {
		// Fatal, not Error: the message check below would panic on a nil error.
		t.Fatal("expected error after all retries exhausted")
	}
	// Should mention the retry count.
	if !strings.Contains(err.Error(), "3 attempts") {
		t.Errorf("error should mention attempt count: %v", err)
	}
}
// TestSendEmailWithRetry_RateLimitPreventsRetry sets a limit of one send per
// minute. The first SendEmailWithRetry consumes that slot (and then fails to
// connect), so the second call must fail with a rate-limit error.
func TestSendEmailWithRetry_RateLimitPreventsRetry(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
		MaxRetries: 3,
		RetryDelay: 0,
		RateLimit:  1, // Only 1 per minute
	}
	manager := NewEnhancedEmailManager(config)

	// First send uses the 1 allowed.
	err := manager.SendEmailWithRetry("Test", "<p>test</p>", "test")
	if err == nil {
		t.Error("expected error (connection should fail)")
	}

	// Second send should hit rate limit on all retries.
	err = manager.SendEmailWithRetry("Test2", "<p>test</p>", "test")
	if err == nil {
		// Fatal, not Error: the message check below would panic on a nil error.
		t.Fatal("expected rate limit error")
	}
	if !strings.Contains(err.Error(), "rate limit exceeded") {
		t.Errorf("expected rate limit error, got: %v", err)
	}
}
// TestSendEmailOnce_BuildsMultipartMessage checks that sendEmailOnce gets past
// message construction with valid inputs (including a Reply-To) and fails only
// on the connection to the unresolvable host.
func TestSendEmailOnce_BuildsMultipartMessage(t *testing.T) {
	// We can't test actual sending, but we can verify the method doesn't panic
	// with valid inputs and returns expected connection error.
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "sender@example.com",
			To:       []string{"recipient@example.com"},
		},
		ReplyTo: "reply@example.com",
	}
	manager := NewEnhancedEmailManager(config)

	err := manager.sendEmailOnce("Test Subject", "<p>HTML Body</p>", "Text Body")
	// Should fail on connection, not on message building. Fatal, not Error:
	// the message check below would panic on a nil error.
	if err == nil {
		t.Fatal("expected connection error")
	}
	// Error should be from connection, not from message construction.
	if strings.Contains(err.Error(), "message") && strings.Contains(err.Error(), "build") {
		t.Errorf("unexpected message build error: %v", err)
	}
}
// TestBuildMultipartEmailMessage_EncodesMultipartBodies builds a message whose
// bodies contain header-like lines, a lone "." line and a fake MIME boundary.
// It then checks that:
//   - no part uses raw 7bit encoding and both parts are quoted-printable;
//   - the From/To/Reply-To/Subject headers round-trip through net/mail;
//   - the body is multipart/alternative with exactly a text part, then an
//     HTML part;
//   - each part decodes back to its line-ending-normalized original body.
func TestBuildMultipartEmailMessage_EncodesMultipartBodies(t *testing.T) {
	addresses := resolvedEmailAddresses{
		from: &mail.Address{
			Name:    "Pulse Sender",
			Address: "sender@example.com",
		},
		to: []*mail.Address{
			{
				Name:    "Recipient",
				Address: "recipient@example.com",
			},
		},
		replyTo: &mail.Address{Address: "reply@example.com"},
	}
	textBody := "Text line 1\nBcc: attacker@example.com\n.\n--pretend-boundary"
	htmlBody := "<p>Hello</p>\nContent-Type: text/plain\n.\n--pretend-boundary"

	msg, err := buildMultipartEmailMessage(addresses, "Alert Subject", htmlBody, textBody, time.Unix(1711711711, 1234).UTC())
	if err != nil {
		t.Fatalf("buildMultipartEmailMessage() error = %v", err)
	}

	raw := string(msg)
	if strings.Contains(raw, "Content-Transfer-Encoding: 7bit") {
		t.Fatalf("message should not use raw 7bit body encoding:\n%s", raw)
	}
	if count := strings.Count(raw, "Content-Transfer-Encoding: quoted-printable"); count != 2 {
		t.Fatalf("expected two quoted-printable parts, got %d", count)
	}

	parsed, err := mail.ReadMessage(bytes.NewReader(msg))
	if err != nil {
		t.Fatalf("mail.ReadMessage() error = %v", err)
	}
	if got := parsed.Header.Get("From"); got != addresses.from.String() {
		t.Fatalf("From header = %q, want %q", got, addresses.from.String())
	}
	if got := parsed.Header.Get("To"); got != formatHeaderAddresses(addresses.to) {
		t.Fatalf("To header = %q, want %q", got, formatHeaderAddresses(addresses.to))
	}
	if got := parsed.Header.Get("Reply-To"); got != addresses.replyTo.String() {
		t.Fatalf("Reply-To header = %q, want %q", got, addresses.replyTo.String())
	}
	if got := parsed.Header.Get("Subject"); got != "Alert Subject" {
		t.Fatalf("Subject header = %q, want %q", got, "Alert Subject")
	}

	mediaType, params, err := mime.ParseMediaType(parsed.Header.Get("Content-Type"))
	if err != nil {
		t.Fatalf("mime.ParseMediaType() error = %v", err)
	}
	if mediaType != "multipart/alternative" {
		t.Fatalf("content type = %q, want %q", mediaType, "multipart/alternative")
	}

	reader := multipart.NewReader(parsed.Body, params["boundary"])

	// readQuotedPrintablePart advances to the next part and returns its decoded
	// body. It fails the test unless the part declares quoted-printable
	// transfer encoding.
	//
	// NextRawPart is used rather than NextPart because NextPart hides a
	// quoted-printable Content-Transfer-Encoding header and decodes the body
	// itself. Either behavior would defeat the checks here.
	readQuotedPrintablePart := func(label string) string {
		t.Helper()
		part, err := reader.NextRawPart()
		if err != nil {
			t.Fatalf("NextRawPart() %s error = %v", label, err)
		}
		if got := part.Header.Get("Content-Transfer-Encoding"); got != "quoted-printable" {
			t.Fatalf("%s part transfer encoding = %q, want %q", label, got, "quoted-printable")
		}
		decoded, err := io.ReadAll(quotedprintable.NewReader(part))
		if err != nil {
			t.Fatalf("ReadAll(%s part) error = %v", label, err)
		}
		return string(decoded)
	}

	// The text part is expected first, then the HTML part.
	if got, want := readQuotedPrintablePart("text"), normalizeEmailBodyLineEndings(textBody); got != want {
		t.Fatalf("decoded text body = %q, want %q", got, want)
	}
	if got, want := readQuotedPrintablePart("html"), normalizeEmailBodyLineEndings(htmlBody); got != want {
		t.Fatalf("decoded html body = %q, want %q", got, want)
	}

	// No further parts may follow.
	if _, err := reader.NextRawPart(); err != io.EOF {
		t.Fatalf("expected multipart EOF, got %v", err)
	}
}
// TestTestConnection_TLSRouting checks that TestConnection chooses a TLS dial
// when TLS is enabled or the port is 465. The host is unresolvable, so every
// case must fail. For the TLS cases, the error must not come from the plain
// TCP dial path.
func TestTestConnection_TLSRouting(t *testing.T) {
	tests := []struct {
		name    string
		tls     bool
		port    int
		wantTLS bool
	}{
		{"TLS true uses TLS dial", true, 587, true},
		{"Port 465 uses TLS dial", false, 465, true},
		{"Port 587 without TLS uses plain dial", false, 587, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := EmailProviderConfig{
				EmailConfig: EmailConfig{
					SMTPHost: "invalid.host.test",
					SMTPPort: tt.port,
					TLS:      tt.tls,
				},
			}
			manager := NewEnhancedEmailManager(config)

			err := manager.TestConnection()
			if err == nil {
				// Fatal, not Error: the message check below would panic on a nil error.
				t.Fatal("expected connection error")
			}
			// Verify error message indicates correct connection type.
			if tt.wantTLS && strings.Contains(err.Error(), "TCP dial") {
				t.Error("TLS connection should not produce TCP dial error")
			}
		})
	}
}
// TestTestConnection_TLSUsesDialerTimeout replaces the package-level
// smtpTLSDialWithDialer hook with a stub that records the dialer's timeout and
// fails. It asserts that TestConnection surfaces the stub's error and dialed
// with a 10s timeout.
func TestTestConnection_TLSUsesDialerTimeout(t *testing.T) {
	mgr := NewEnhancedEmailManager(EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "smtp.example.com",
			SMTPPort: 465,
			TLS:      true,
		},
	})

	restore := smtpTLSDialWithDialer
	t.Cleanup(func() { smtpTLSDialWithDialer = restore })

	var observed time.Duration
	smtpTLSDialWithDialer = func(d *net.Dialer, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
		if d != nil {
			observed = d.Timeout
		}
		return nil, fmt.Errorf("tls dial intercepted")
	}

	err := mgr.TestConnection()
	switch {
	case err == nil:
		t.Fatal("expected connection error")
	case !strings.Contains(err.Error(), "tls dial intercepted"):
		t.Fatalf("expected intercepted TLS dial error, got %v", err)
	case observed != 10*time.Second:
		t.Fatalf("expected TLS dial timeout of 10s, got %s", observed)
	}
}
// TestSendTLS_ConnectionError calls sendTLS directly against an unresolvable
// host and expects a "TLS dial failed" error.
func TestSendTLS_ConnectionError(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 465,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
	}
	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendTLS("invalid.host.test:465", []byte("test"), addresses)
	if err == nil {
		// Fatal, not Error: the message check below would panic on a nil error.
		t.Fatal("expected TLS dial error")
	}
	if !strings.Contains(err.Error(), "TLS dial failed") {
		t.Errorf("expected TLS dial error, got: %v", err)
	}
}
// TestSendStartTLS_ConnectionError calls sendStartTLS directly against an
// unresolvable host and expects a "TCP dial failed" error.
func TestSendStartTLS_ConnectionError(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
	}
	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendStartTLS("invalid.host.test:587", []byte("test"), addresses)
	if err == nil {
		// Fatal, not Error: the message check below would panic on a nil error.
		t.Fatal("expected TCP dial error")
	}
	if !strings.Contains(err.Error(), "TCP dial failed") {
		t.Errorf("expected TCP dial error, got: %v", err)
	}
}
// TestSendPlain_ConnectionError calls sendPlain directly against an
// unresolvable host and expects a "TCP dial failed" error.
func TestSendPlain_ConnectionError(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 25,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
	}
	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendPlain("invalid.host.test:25", []byte("test"), addresses)
	if err == nil {
		// Fatal, not Error: the message check below would panic on a nil error.
		t.Fatal("expected TCP dial error")
	}
	if !strings.Contains(err.Error(), "TCP dial failed") {
		t.Errorf("expected TCP dial error, got: %v", err)
	}
}
// TestCheckRateLimit_NegativeLimit verifies that a negative RateLimit behaves
// like "no limit": every check succeeds.
func TestCheckRateLimit_NegativeLimit(t *testing.T) {
	mgr := NewEnhancedEmailManager(EmailProviderConfig{RateLimit: -1})

	for remaining := 10; remaining > 0; remaining-- {
		if err := mgr.checkRateLimit(); err != nil {
			t.Errorf("negative rate limit should allow all calls: %v", err)
		}
	}
}
func TestCheckRateLimit_Concurrency(t *testing.T) {
manager := NewEnhancedEmailManager(EmailProviderConfig{
RateLimit: 100,
})
// Run concurrent rate limit checks
done := make(chan bool, 50)
for i := 0; i < 50; i++ {
go func() {
_ = manager.checkRateLimit()
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 50; i++ {
<-done
}
// Verify counter is correct (should be 50)
manager.rateLimit.mu.Lock()
count := manager.rateLimit.sentCount
manager.rateLimit.mu.Unlock()
if count != 50 {
t.Errorf("expected count 50 after concurrent calls, got %d", count)
}
}
func TestSanitizeEmailHeaderValue(t *testing.T) {
t.Parallel()
got := sanitizeEmailHeaderValue("Alert subject\r\nBcc: attacker@example.com")
want := "Alert subject Bcc: attacker@example.com"
if got != want {
t.Fatalf("sanitizeEmailHeaderValue() = %q, want %q", got, want)
}
}
// TestSendEmailOnceRejectsInvalidFromAddress feeds a From value that smuggles
// a CRLF-injected Bcc header. sendEmailOnce must reject it with an
// "invalid from address" error.
func TestSendEmailOnceRejectsInvalidFromAddress(t *testing.T) {
	mgr := NewEnhancedEmailManager(EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "sender@example.com\r\nBcc: attacker@example.com",
			To:       []string{"recipient@example.com"},
		},
	})

	err := mgr.sendEmailOnce("Test Subject", "<p>HTML Body</p>", "Text Body")
	switch {
	case err == nil:
		t.Fatal("expected invalid from address error")
	case !strings.Contains(err.Error(), "invalid from address"):
		t.Fatalf("expected from-address validation error, got %v", err)
	}
}
// TestSendViaProviderRejectsInvalidRecipientAddress feeds a recipient that
// smuggles a CRLF-injected Cc header. sendViaProvider must reject it with an
// "invalid recipient address" error.
func TestSendViaProviderRejectsInvalidRecipientAddress(t *testing.T) {
	mgr := NewEnhancedEmailManager(EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 587,
			From:     "sender@example.com",
			To:       []string{"recipient@example.com\r\nCc: attacker@example.com"},
		},
	})

	err := mgr.sendViaProvider([]byte("test"))
	switch {
	case err == nil:
		t.Fatal("expected invalid recipient address error")
	case !strings.Contains(err.Error(), "invalid recipient address"):
		t.Fatalf("expected recipient validation error, got %v", err)
	}
}
// TestSendPlain_ConnectsToFakeServer-style check: TestSendPlain_Success runs
// sendPlain against a scripted in-memory SMTP server on one end of a net.Pipe
// and expects the exchange to complete without error. The package-level
// smtpDialTimeout hook is swapped so the client end of the pipe is returned
// instead of a real TCP connection.
func TestSendPlain_Success(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "smtp.example.com",
			SMTPPort: 25,
			From:     "test@example.com",
			To:       []string{"recipient@example.com"},
		},
	}

	clientConn, serverConn := net.Pipe()

	origDial := smtpDialTimeout
	smtpDialTimeout = func(network, addr string, timeout time.Duration) (net.Conn, error) {
		return clientConn, nil
	}
	t.Cleanup(func() { smtpDialTimeout = origDial })

	// done is closed once the fake server goroutine has exited.
	done := make(chan struct{})
	go func() {
		defer close(done)
		defer serverConn.Close()
		w := bufio.NewWriter(serverConn)
		r := textproto.NewReader(bufio.NewReader(serverConn))
		// Greeting
		fmt.Fprint(w, "220 smtp.example.com ESMTP\r\n")
		_ = w.Flush()
		for {
			line, err := r.ReadLine()
			if err != nil {
				return
			}
			switch {
			case strings.HasPrefix(line, "HELO") || strings.HasPrefix(line, "EHLO"):
				fmt.Fprint(w, "250-smtp.example.com\r\n250 8BITMIME\r\n")
				_ = w.Flush()
			case strings.HasPrefix(line, "MAIL FROM:"):
				fmt.Fprint(w, "250 OK\r\n")
				_ = w.Flush()
			case strings.HasPrefix(line, "RCPT TO:"):
				fmt.Fprint(w, "250 OK\r\n")
				_ = w.Flush()
			case strings.HasPrefix(line, "DATA"):
				fmt.Fprint(w, "354 Start mail input; end with <CRLF>.<CRLF>\r\n")
				_ = w.Flush()
				// Consume the message body up to the terminating "." line.
				for {
					l, err := r.ReadLine()
					if err != nil || l == "." {
						break
					}
				}
				fmt.Fprint(w, "250 OK\r\n")
				_ = w.Flush()
			case strings.HasPrefix(line, "QUIT"):
				fmt.Fprint(w, "221 Bye\r\n")
				_ = w.Flush()
				return
			default:
				// Default OK to tolerate extra commands.
				fmt.Fprint(w, "250 OK\r\n")
				_ = w.Flush()
			}
		}
	}()
	// Closing both pipe ends unblocks any read or write the server is still
	// blocked on, so waiting for done cannot hang. It guarantees the goroutine
	// does not outlive the test, including when a t.Fatalf below fires.
	defer func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
		<-done
	}()

	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendPlain("ignored:25", []byte("Test Message"), addresses)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
}
// TestSendTLS_Success is, despite its name, an error-path check. It points
// sendTLS at 127.0.0.1:1, where no TLS server is expected, and verifies the
// failure comes from the TLS dial or is a remote TLS error.
func TestSendTLS_Success(t *testing.T) {
	// We don't need a real TLS server here. This is an error-path sanity check that
	// exercises the TLS dialer logic without binding ports (which can be blocked in CI).
	addr := "127.0.0.1:1"
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "invalid.host.test",
			SMTPPort: 1,
			TLS:      true,
			From:     "sender@example.com",
			To:       []string{"recipient@example.com"},
		},
		SkipTLSVerify: true,
	}
	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendTLS(addr, []byte("test"), addresses)
	// It will still fail because we aren't running a real TLS server here,
	// but we can verify it reaches the TLS dialer. Fatal, not Error: the
	// message check below would panic on a nil error.
	if err == nil {
		t.Fatal("expected TLS error")
	}
	if !strings.Contains(err.Error(), "TLS dial failed") && !strings.Contains(err.Error(), "remote error") {
		t.Errorf("unexpected error: %v", err)
	}
}
// TestSendStartTLS_Success drives sendStartTLS against a scripted in-memory
// SMTP server on one end of a net.Pipe. The server advertises STARTTLS and
// accepts the upgrade, then stops without speaking TLS, so the call is
// expected to fail at the TLS upgrade. The package-level smtpDialTimeout hook
// is swapped so the client end of the pipe is returned instead of a real TCP
// connection.
func TestSendStartTLS_Success(t *testing.T) {
	config := EmailProviderConfig{
		EmailConfig: EmailConfig{
			SMTPHost: "smtp.example.com",
			SMTPPort: 587,
			StartTLS: true,
			From:     "sender@example.com",
			To:       []string{"recipient@example.com"},
		},
		SkipTLSVerify: true,
	}

	clientConn, serverConn := net.Pipe()

	origDial := smtpDialTimeout
	smtpDialTimeout = func(network, addr string, timeout time.Duration) (net.Conn, error) {
		return clientConn, nil
	}
	t.Cleanup(func() { smtpDialTimeout = origDial })

	// done is closed once the fake server goroutine has exited.
	done := make(chan struct{})
	go func() {
		defer close(done)
		defer serverConn.Close()
		w := bufio.NewWriter(serverConn)
		r := textproto.NewReader(bufio.NewReader(serverConn))
		fmt.Fprint(w, "220 smtp.example.com ESMTP\r\n")
		_ = w.Flush()
		for {
			line, err := r.ReadLine()
			if err != nil {
				return
			}
			if strings.HasPrefix(line, "EHLO") {
				fmt.Fprint(w, "250-smtp.example.com\r\n250 STARTTLS\r\n")
				_ = w.Flush()
				continue
			}
			if strings.HasPrefix(line, "STARTTLS") {
				fmt.Fprint(w, "220 Ready to start TLS\r\n")
				_ = w.Flush()
				return // Client will attempt TLS handshake and fail.
			}
			fmt.Fprint(w, "250 OK\r\n")
			_ = w.Flush()
		}
	}()
	// Closing both pipe ends unblocks any read or write the server is still
	// blocked on, so waiting for done cannot hang. It guarantees the goroutine
	// does not outlive the test, including when a t.Fatalf below fires.
	defer func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
		<-done
	}()

	manager := NewEnhancedEmailManager(config)
	addresses, err := manager.resolveEmailAddresses()
	if err != nil {
		t.Fatalf("resolveEmailAddresses() error = %v", err)
	}

	err = manager.sendStartTLS("ignored:587", []byte("Test Message"), addresses)
	// Should fail at TLS upgrade because mock server doesn't actually do TLS.
	if err == nil {
		t.Error("expected STARTTLS upgrade error")
	}
}