diff --git a/main.go b/main.go index a78c185..67d102e 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,12 @@ package main import ( "context" "crypto/tls" + "crypto/x509" "encoding/csv" "errors" "flag" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -42,20 +44,22 @@ func arg_fail(msg string) { } type CLIArgs struct { - country string - listCountries bool - listProxies bool - bindAddress string - verbosity int - timeout time.Duration - showVersion bool - proxy string - apiLogin string - apiPassword string - apiAddress string - bootstrapDNS string - refresh time.Duration - refreshRetry time.Duration + country string + listCountries bool + listProxies bool + bindAddress string + verbosity int + timeout time.Duration + showVersion bool + proxy string + apiLogin string + apiPassword string + apiAddress string + bootstrapDNS string + refresh time.Duration + refreshRetry time.Duration + certChainWorkaround bool + caFile string } func parse_args() CLIArgs { @@ -80,6 +84,9 @@ func parse_args() CLIArgs { "Examples: https://1.1.1.1/dns-query, quic://dns.adguard.com") flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval") flag.DurationVar(&args.refreshRetry, "refresh-retry", 5*time.Second, "login refresh retry interval") + flag.BoolVar(&args.certChainWorkaround, "certchain-workaround", true, + "add bundled cross-signed intermediate cert to certchain to make it check out on old systems") + flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file") flag.Parse() if args.country == "" { arg_fail("Country can't be empty string.") @@ -259,7 +266,21 @@ func run() int { return basic_auth_header(seclient.GetProxyCredentials()) } - handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, dialer) + var caPool *x509.CertPool + if args.caFile != "" { + caPool = x509.NewCertPool() + certs, err := ioutil.ReadFile(args.caFile) + if err != nil { + mainLogger.Error("Can't load CA file: %v", err) + return 15 + } + if ok := caPool.AppendCertsFromPEM(certs); !ok { + mainLogger.Error("Can't load certificates from CA file") + return 15 + } + } + + handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, args.certChainWorkaround, caPool, dialer) mainLogger.Info("Endpoint: %s", endpoint.NetAddr()) mainLogger.Info("Starting proxy server...") handler := NewProxyHandler(handlerDialer, proxyLogger) diff --git a/upstream.go b/upstream.go index ff2a0a0..3396564 100644 --- a/upstream.go +++ b/upstream.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io" @@ -20,10 +21,37 @@ const ( PROXY_CONNECT_METHOD = "CONNECT" PROXY_HOST_HEADER = "Host" PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization" + MISSING_CHAIN_CERT = `-----BEGIN CERTIFICATE----- +MIID0zCCArugAwIBAgIQVmcdBOpPmUxvEIFHWdJ1lDANBgkqhkiG9w0BAQwFADB7 +MQswCQYDVQQGEwJHQjEbMBkGA1UECAwSR3JlYXRlciBNYW5jaGVzdGVyMRAwDgYD +VQQHDAdTYWxmb3JkMRowGAYDVQQKDBFDb21vZG8gQ0EgTGltaXRlZDEhMB8GA1UE +AwwYQUFBIENlcnRpZmljYXRlIFNlcnZpY2VzMB4XDTE5MDMxMjAwMDAwMFoXDTI4 +MTIzMTIzNTk1OVowgYgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpOZXcgSmVyc2V5 +MRQwEgYDVQQHEwtKZXJzZXkgQ2l0eTEeMBwGA1UEChMVVGhlIFVTRVJUUlVTVCBO +ZXR3b3JrMS4wLAYDVQQDEyVVU0VSVHJ1c3QgRUNDIENlcnRpZmljYXRpb24gQXV0 +aG9yaXR5MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEGqxUWqn5aCPnetUkb1PGWthL +q8bVttHmc3Gu3ZzWDGH926CJA7gFFOxXzu5dP+Ihs8731Ip54KODfi2X0GHE8Znc +JZFjq38wo7Rw4sehM5zzvy5cU7Ffs30yf4o043l5o4HyMIHvMB8GA1UdIwQYMBaA +FKARCiM+lvEH7OKvKe+CpX/QMKS0MB0GA1UdDgQWBBQ64QmG1M8ZwpZ2dEl23OA1 +xmNjmjAOBgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zARBgNVHSAECjAI +MAYGBFUdIAAwQwYDVR0fBDwwOjA4oDagNIYyaHR0cDovL2NybC5jb21vZG9jYS5j +b20vQUFBQ2VydGlmaWNhdGVTZXJ2aWNlcy5jcmwwNAYIKwYBBQUHAQEEKDAmMCQG +CCsGAQUFBzABhhhodHRwOi8vb2NzcC5jb21vZG9jYS5jb20wDQYJKoZIhvcNAQEM +BQADggEBABns652JLCALBIAdGN5CmXKZFjK9Dpx1WywV4ilAbe7/ctvbq5AfjJXy +ij0IckKJUAfiORVsAYfZFhr1wHUrxeZWEQff2Ji8fJ8ZOd+LygBkc7xGEJuTI42+ +FsMuCIKchjN0djsoTI0DQoWz4rIjQtUfenVqGtF8qmchxDM6OW1TyaLtYiKou+JV +bJlsQ2uRl9EMC5MCHdK8aXdJ5htN978UeAOwproLtOGFfy/cQjutdAFI3tZs4RmY +CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf +8qn0dNW44bOwgeThpWOjzOoEeJBuv/c= +-----END CERTIFICATE----- +` ) var UpstreamBlockedError = errors.New("blocked by upstream") +var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT)) +var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes) + type Dialer interface { Dial(network, address string) (net.Conn, error) } @@ -34,18 +62,22 @@ type ContextDialer interface { } type ProxyDialer struct { - address string - tlsServerName string - auth AuthProvider - next ContextDialer + address string + tlsServerName string + auth AuthProvider + next ContextDialer + intermediateWorkaround bool + caPool *x509.CertPool } -func NewProxyDialer(address, tlsServerName string, auth AuthProvider, nextDialer ContextDialer) *ProxyDialer { +func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer { return &ProxyDialer{ - address: address, - tlsServerName: tlsServerName, - auth: auth, - next: nextDialer, + address: address, + tlsServerName: tlsServerName, + auth: auth, + next: nextDialer, + intermediateWorkaround: intermediateWorkaround, + caPool: caPool, } } @@ -79,7 +111,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) { return authHeader } } - return NewProxyDialer(address, tlsServerName, auth, next), nil + return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil } func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { @@ -105,9 +137,18 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) opts := x509.VerifyOptions{ DNSName: d.tlsServerName, Intermediates: x509.NewCertPool(), + Roots: d.caPool, } + waRequired := false for _, cert := range cs.PeerCertificates[1:] { opts.Intermediates.AddCert(cert) + if d.intermediateWorkaround && !waRequired && + bytes.Compare(cert.AuthorityKeyId, missingLink.SubjectKeyId) == 0 { + waRequired = true + } + } + if waRequired { + opts.Intermediates.AddCert(missingLink) } _, err := cs.PeerCertificates[0].Verify(opts) return err