Merge pull request #71 from Snawoot/refactor

Split main package into subpackages
This commit is contained in:
Snawoot 2024-11-04 12:26:40 +02:00 committed by GitHub
commit 3094f4ad97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 263 additions and 242 deletions

50
clock/clock.go Normal file
View file

@ -0,0 +1,50 @@
package clock
import (
"context"
"time"
)
const WALLCLOCK_PRECISION = 1 * time.Second
func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <- t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <- t
return
}
}
}
}()
return ch
}
func RunTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
go func() {
var err error
for {
nextInterval := interval
if err != nil {
nextInterval = retryInterval
}
select {
case <-ctx.Done():
return
case <-AfterWallClock(nextInterval):
err = cb(ctx)
}
}
}()
}

View file

@ -1,4 +1,4 @@
package main package dialer
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package main package dialer
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package main package dialer
import ( import (
"bufio" "bufio"
@ -6,6 +6,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@ -47,11 +48,11 @@ CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf
` `
) )
var UpstreamBlockedError = errors.New("blocked by upstream")
var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT)) var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT))
var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes) var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes)
type stringCb = func() (string, error)
type Dialer interface { type Dialer interface {
Dial(network, address string) (net.Conn, error) Dial(network, address string) (net.Conn, error)
} }
@ -62,15 +63,15 @@ type ContextDialer interface {
} }
type ProxyDialer struct { type ProxyDialer struct {
address string address stringCb
tlsServerName string tlsServerName stringCb
auth AuthProvider auth stringCb
next ContextDialer next ContextDialer
intermediateWorkaround bool intermediateWorkaround bool
caPool *x509.CertPool caPool *x509.CertPool
} }
func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer { func NewProxyDialer(address, tlsServerName, auth stringCb, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
return &ProxyDialer{ return &ProxyDialer{
address: address, address: address,
tlsServerName: tlsServerName, tlsServerName: tlsServerName,
@ -85,7 +86,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
host := u.Hostname() host := u.Hostname()
port := u.Port() port := u.Port()
tlsServerName := "" tlsServerName := ""
var auth AuthProvider = nil var auth stringCb = nil
switch strings.ToLower(u.Scheme) { switch strings.ToLower(u.Scheme) {
case "http": case "http":
@ -106,12 +107,9 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
if u.User != nil { if u.User != nil {
username := u.User.Username() username := u.User.Username()
password, _ := u.User.Password() password, _ := u.User.Password()
authHeader := basic_auth_header(username, password) auth = WrapStringToCb(BasicAuthHeader(username, password))
auth = func() string {
return authHeader
} }
} return NewProxyDialer(WrapStringToCb(address), WrapStringToCb(tlsServerName), auth, false, nil, next), nil
return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil
} }
func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@ -121,12 +119,20 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
return nil, errors.New("bad network specified for DialContext: only tcp is supported") return nil, errors.New("bad network specified for DialContext: only tcp is supported")
} }
conn, err := d.next.DialContext(ctx, "tcp", d.address) uAddress, err := d.address()
if err != nil {
return nil, err
}
conn, err := d.next.DialContext(ctx, "tcp", uAddress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if d.tlsServerName != "" { uTLSServerName, err := d.tlsServerName()
if err != nil {
return nil, err
}
if uTLSServerName != "" {
// Custom cert verification logic: // Custom cert verification logic:
// DO NOT send SNI extension of TLS ClientHello // DO NOT send SNI extension of TLS ClientHello
// DO peer certificate verification against specified servername // DO peer certificate verification against specified servername
@ -135,7 +141,7 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
InsecureSkipVerify: true, InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
DNSName: d.tlsServerName, DNSName: uTLSServerName,
Intermediates: x509.NewCertPool(), Intermediates: x509.NewCertPool(),
Roots: d.caPool, Roots: d.caPool,
} }
@ -169,7 +175,11 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
} }
if d.auth != nil { if d.auth != nil {
req.Header.Set(PROXY_AUTHORIZATION_HEADER, d.auth()) auth, err := d.auth()
if err != nil {
return nil, err
}
req.Header.Set(PROXY_AUTHORIZATION_HEADER, auth)
} }
rawreq, err := httputil.DumpRequest(req, false) rawreq, err := httputil.DumpRequest(req, false)
@ -188,10 +198,6 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
} }
if proxyResp.StatusCode != http.StatusOK { if proxyResp.StatusCode != http.StatusOK {
if proxyResp.StatusCode == http.StatusForbidden &&
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
return nil, UpstreamBlockedError
}
return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status)) return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status))
} }
@ -228,3 +234,14 @@ func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
} }
return http.ReadResponse(bufio.NewReader(buf), req) return http.ReadResponse(bufio.NewReader(buf), req)
} }
func BasicAuthHeader(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
}
func WrapStringToCb(s string) func() (string, error) {
return func() (string, error) {
return s, nil
}
}

View file

@ -1,23 +1,33 @@
package main package handler
import ( import (
"bufio"
"context"
"errors"
"fmt" "fmt"
"io"
"net"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
"github.com/Snawoot/opera-proxy/dialer"
clog "github.com/Snawoot/opera-proxy/log"
) )
const BAD_REQ_MSG = "Bad Request\n" const (
COPY_BUF = 128 * 1024
type AuthProvider func() string BAD_REQ_MSG = "Bad Request\n"
)
type ProxyHandler struct { type ProxyHandler struct {
logger *CondLogger logger *clog.CondLogger
dialer ContextDialer dialer dialer.ContextDialer
httptransport http.RoundTripper httptransport http.RoundTripper
} }
func NewProxyHandler(dialer ContextDialer, logger *CondLogger) *ProxyHandler { func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler {
httptransport := &http.Transport{ httptransport := &http.Transport{
MaxIdleConns: 100, MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
@ -104,3 +114,128 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.HandleRequest(wr, req) s.HandleRequest(wr, req)
} }
} }
func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{}
cpy := func(dst, src net.Conn) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
wg.Add(2)
go cpy(left, right)
go cpy(right, left)
groupdone := make(chan struct{})
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
left.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{}
ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done()
copyBody(dst, src)
}
wg.Add(2)
go ltr(right, leftreader)
go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1)
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
leftreader.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Connection",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func delHopHeaders(header http.Header) {
for _, h := range hopHeaders {
header.Del(h)
}
}
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker)
if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking")
}
conn, rw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
var emptytime time.Time
err = conn.SetDeadline(emptytime)
if err != nil {
conn.Close()
return nil, nil, err
}
return conn, rw, nil
}
func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher)
if !ok {
return false
}
f.Flush()
return true
}
func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF)
for {
bread, read_err := body.Read(buf)
var write_err error
if bread > 0 {
_, write_err = wr.Write(buf[:bread])
flush(wr)
}
if read_err != nil || write_err != nil {
break
}
}
}

View file

@ -1,4 +1,4 @@
package main package log
import ( import (
"fmt" "fmt"

View file

@ -1,4 +1,4 @@
package main package log
import ( import (
"errors" "errors"

49
main.go
View file

@ -22,6 +22,10 @@ import (
xproxy "golang.org/x/net/proxy" xproxy "golang.org/x/net/proxy"
"github.com/Snawoot/opera-proxy/clock"
"github.com/Snawoot/opera-proxy/dialer"
"github.com/Snawoot/opera-proxy/handler"
clog "github.com/Snawoot/opera-proxy/log"
se "github.com/Snawoot/opera-proxy/seclient" se "github.com/Snawoot/opera-proxy/seclient"
) )
@ -151,12 +155,12 @@ func parse_args() *CLIArgs {
} }
func proxyFromURLWrapper(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) { func proxyFromURLWrapper(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
cdialer, ok := next.(ContextDialer) cdialer, ok := next.(dialer.ContextDialer)
if !ok { if !ok {
return nil, errors.New("only context dialers are accepted") return nil, errors.New("only context dialers are accepted")
} }
return ProxyDialerFromURL(u, cdialer) return dialer.ProxyDialerFromURL(u, cdialer)
} }
func run() int { func run() int {
@ -166,19 +170,19 @@ func run() int {
return 0 return 0
} }
logWriter := NewLogWriter(os.Stderr) logWriter := clog.NewLogWriter(os.Stderr)
defer logWriter.Close() defer logWriter.Close()
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ",
log.LstdFlags|log.Lshortfile), log.LstdFlags|log.Lshortfile),
args.verbosity) args.verbosity)
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ", proxyLogger := clog.NewCondLogger(log.New(logWriter, "PROXY : ",
log.LstdFlags|log.Lshortfile), log.LstdFlags|log.Lshortfile),
args.verbosity) args.verbosity)
mainLogger.Info("opera-proxy client version %s is starting...", version) mainLogger.Info("opera-proxy client version %s is starting...", version)
var dialer ContextDialer = &net.Dialer{ var d dialer.ContextDialer = &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
} }
@ -191,22 +195,22 @@ func run() int {
mainLogger.Critical("Unable to parse base proxy URL: %v", err) mainLogger.Critical("Unable to parse base proxy URL: %v", err)
return 6 return 6
} }
pxDialer, err := xproxy.FromURL(proxyURL, dialer) pxDialer, err := xproxy.FromURL(proxyURL, d)
if err != nil { if err != nil {
mainLogger.Critical("Unable to instantiate base proxy dialer: %v", err) mainLogger.Critical("Unable to instantiate base proxy dialer: %v", err)
return 7 return 7
} }
dialer = pxDialer.(ContextDialer) d = pxDialer.(dialer.ContextDialer)
} }
seclientDialer := dialer seclientDialer := d
if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 { if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 {
var apiAddress string var apiAddress string
if args.apiAddress != "" { if args.apiAddress != "" {
apiAddress = args.apiAddress apiAddress = args.apiAddress
mainLogger.Info("Using fixed API host IP address = %s", apiAddress) mainLogger.Info("Using fixed API host IP address = %s", apiAddress)
} else { } else {
resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout) resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout)
if err != nil { if err != nil {
mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) mainLogger.Critical("Unable to instantiate DNS resolver: %v", err)
return 4 return 4
@ -234,7 +238,7 @@ func run() int {
apiAddress = addrs[0].String() apiAddress = addrs[0].String()
mainLogger.Info("Discovered address of API host = %s", apiAddress) mainLogger.Info("Discovered address of API host = %s", apiAddress)
} }
seclientDialer = NewFixedDialer(apiAddress, dialer) seclientDialer = dialer.NewFixedDialer(apiAddress, d)
} }
// Dialing w/o SNI, receiving self-signed certificate, so skip verification. // Dialing w/o SNI, receiving self-signed certificate, so skip verification.
@ -303,7 +307,7 @@ func run() int {
return 13 return 13
} }
runTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...") mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout) reqCtx, cl := context.WithTimeout(ctx, args.timeout)
defer cl() defer cl()
@ -327,9 +331,6 @@ func run() int {
}) })
endpoint := ips[0] endpoint := ips[0]
auth := func() string {
return basic_auth_header(seclient.GetProxyCredentials())
}
var caPool *x509.CertPool var caPool *x509.CertPool
if args.caFile != "" { if args.caFile != "" {
@ -345,18 +346,26 @@ func run() int {
} }
} }
handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, args.certChainWorkaround, caPool, dialer) handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
mainLogger.Info("Endpoint: %s", endpoint.NetAddr()) mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
mainLogger.Info("Starting proxy server...") mainLogger.Info("Starting proxy server...")
handler := NewProxyHandler(handlerDialer, proxyLogger) h := handler.NewProxyHandler(handlerDialer, proxyLogger)
mainLogger.Info("Init complete.") mainLogger.Info("Init complete.")
err = http.ListenAndServe(args.bindAddress, handler) err = http.ListenAndServe(args.bindAddress, h)
mainLogger.Critical("Server terminated with a reason: %v", err) mainLogger.Critical("Server terminated with a reason: %v", err)
mainLogger.Info("Shutting down...") mainLogger.Info("Shutting down...")
return 0 return 0
} }
func printCountries(logger *CondLogger, timeout time.Duration, seclient *se.SEClient) int { func printCountries(logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int {
ctx, cl := context.WithTimeout(context.Background(), timeout) ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl() defer cl()
list, err := seclient.GeoList(ctx) list, err := seclient.GeoList(ctx)
@ -380,7 +389,7 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
login, password := seclient.GetProxyCredentials() login, password := seclient.GetProxyCredentials()
fmt.Println("Proxy login:", login) fmt.Println("Proxy login:", login)
fmt.Println("Proxy password:", password) fmt.Println("Proxy password:", password)
fmt.Println("Proxy-Authorization:", basic_auth_header(login, password)) fmt.Println("Proxy-Authorization:", dialer.BasicAuthHeader(login, password))
fmt.Println("") fmt.Println("")
wr.Write([]string{"host", "ip_address", "port"}) wr.Write([]string{"host", "ip_address", "port"})
for i, ip := range ips { for i, ip := range ips {

190
utils.go
View file

@ -1,190 +0,0 @@
package main
import (
"bufio"
"context"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"sync"
"time"
)
const (
COPY_BUF = 128 * 1024
WALLCLOCK_PRECISION = 1 * time.Second
)
func basic_auth_header(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
}
func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{}
cpy := func(dst, src net.Conn) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
wg.Add(2)
go cpy(left, right)
go cpy(right, left)
groupdone := make(chan struct{})
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
left.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{}
ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done()
copyBody(dst, src)
}
wg.Add(2)
go ltr(right, leftreader)
go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1)
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
leftreader.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Connection",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func delHopHeaders(header http.Header) {
for _, h := range hopHeaders {
header.Del(h)
}
}
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker)
if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking")
}
conn, rw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
var emptytime time.Time
err = conn.SetDeadline(emptytime)
if err != nil {
conn.Close()
return nil, nil, err
}
return conn, rw, nil
}
func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher)
if !ok {
return false
}
f.Flush()
return true
}
func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF)
for {
bread, read_err := body.Read(buf)
var write_err error
if bread > 0 {
_, write_err = wr.Write(buf[:bread])
flush(wr)
}
if read_err != nil || write_err != nil {
break
}
}
}
func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <- t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <- t
return
}
}
}
}()
return ch
}
func runTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
go func() {
var err error
for {
nextInterval := interval
if err != nil {
nextInterval = retryInterval
}
select {
case <-ctx.Done():
return
case <-AfterWallClock(nextInterval):
err = cb(ctx)
}
}
}()
}