split main package into subpackages

This commit is contained in:
Vladislav Yarmak 2024-11-03 15:17:29 +02:00
parent 38b2b95dcb
commit a05aad77aa
9 changed files with 263 additions and 242 deletions

50
clock/clock.go Normal file
View file

@ -0,0 +1,50 @@
package clock
import (
"context"
"time"
)
const WALLCLOCK_PRECISION = 1 * time.Second
func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <- t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <- t
return
}
}
}
}()
return ch
}
func RunTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
go func() {
var err error
for {
nextInterval := interval
if err != nil {
nextInterval = retryInterval
}
select {
case <-ctx.Done():
return
case <-AfterWallClock(nextInterval):
err = cb(ctx)
}
}
}()
}

View file

@ -1,4 +1,4 @@
package main
package dialer
import (
"context"

View file

@ -1,4 +1,4 @@
package main
package dialer
import (
"context"

View file

@ -1,4 +1,4 @@
package main
package dialer
import (
"bufio"
@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
@ -47,11 +48,11 @@ CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf
`
)
var UpstreamBlockedError = errors.New("blocked by upstream")
var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT))
var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes)
type stringCb = func() (string, error)
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
@ -62,15 +63,15 @@ type ContextDialer interface {
}
type ProxyDialer struct {
address string
tlsServerName string
auth AuthProvider
address stringCb
tlsServerName stringCb
auth stringCb
next ContextDialer
intermediateWorkaround bool
caPool *x509.CertPool
}
func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
func NewProxyDialer(address, tlsServerName, auth stringCb, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
return &ProxyDialer{
address: address,
tlsServerName: tlsServerName,
@ -85,7 +86,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
host := u.Hostname()
port := u.Port()
tlsServerName := ""
var auth AuthProvider = nil
var auth stringCb = nil
switch strings.ToLower(u.Scheme) {
case "http":
@ -106,12 +107,9 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
if u.User != nil {
username := u.User.Username()
password, _ := u.User.Password()
authHeader := basic_auth_header(username, password)
auth = func() string {
return authHeader
}
auth = WrapStringToCb(BasicAuthHeader(username, password))
}
return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil
return NewProxyDialer(WrapStringToCb(address), WrapStringToCb(tlsServerName), auth, false, nil, next), nil
}
func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@ -121,12 +119,20 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
return nil, errors.New("bad network specified for DialContext: only tcp is supported")
}
conn, err := d.next.DialContext(ctx, "tcp", d.address)
uAddress, err := d.address()
if err != nil {
return nil, err
}
conn, err := d.next.DialContext(ctx, "tcp", uAddress)
if err != nil {
return nil, err
}
if d.tlsServerName != "" {
uTLSServerName, err := d.tlsServerName()
if err != nil {
return nil, err
}
if uTLSServerName != "" {
// Custom cert verification logic:
// DO NOT send SNI extension of TLS ClientHello
// DO peer certificate verification against specified servername
@ -135,7 +141,7 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error {
opts := x509.VerifyOptions{
DNSName: d.tlsServerName,
DNSName: uTLSServerName,
Intermediates: x509.NewCertPool(),
Roots: d.caPool,
}
@ -169,7 +175,11 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
}
if d.auth != nil {
req.Header.Set(PROXY_AUTHORIZATION_HEADER, d.auth())
auth, err := d.auth()
if err != nil {
return nil, err
}
req.Header.Set(PROXY_AUTHORIZATION_HEADER, auth)
}
rawreq, err := httputil.DumpRequest(req, false)
@ -188,10 +198,6 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
}
if proxyResp.StatusCode != http.StatusOK {
if proxyResp.StatusCode == http.StatusForbidden &&
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
return nil, UpstreamBlockedError
}
return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status))
}
@ -228,3 +234,14 @@ func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
}
return http.ReadResponse(bufio.NewReader(buf), req)
}
func BasicAuthHeader(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
}
func WrapStringToCb(s string) func() (string, error) {
return func() (string, error) {
return s, nil
}
}

View file

@ -1,23 +1,33 @@
package main
package handler
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/Snawoot/opera-proxy/dialer"
clog "github.com/Snawoot/opera-proxy/log"
)
const BAD_REQ_MSG = "Bad Request\n"
type AuthProvider func() string
const (
COPY_BUF = 128 * 1024
BAD_REQ_MSG = "Bad Request\n"
)
type ProxyHandler struct {
logger *CondLogger
dialer ContextDialer
logger *clog.CondLogger
dialer dialer.ContextDialer
httptransport http.RoundTripper
}
func NewProxyHandler(dialer ContextDialer, logger *CondLogger) *ProxyHandler {
func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler {
httptransport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
@ -104,3 +114,128 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.HandleRequest(wr, req)
}
}
func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{}
cpy := func(dst, src net.Conn) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
wg.Add(2)
go cpy(left, right)
go cpy(right, left)
groupdone := make(chan struct{})
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
left.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{}
ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done()
copyBody(dst, src)
}
wg.Add(2)
go ltr(right, leftreader)
go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1)
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
leftreader.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Connection",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func delHopHeaders(header http.Header) {
for _, h := range hopHeaders {
header.Del(h)
}
}
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker)
if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking")
}
conn, rw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
var emptytime time.Time
err = conn.SetDeadline(emptytime)
if err != nil {
conn.Close()
return nil, nil, err
}
return conn, rw, nil
}
func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher)
if !ok {
return false
}
f.Flush()
return true
}
func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF)
for {
bread, read_err := body.Read(buf)
var write_err error
if bread > 0 {
_, write_err = wr.Write(buf[:bread])
flush(wr)
}
if read_err != nil || write_err != nil {
break
}
}
}

View file

@ -1,4 +1,4 @@
package main
package log
import (
"fmt"

View file

@ -1,4 +1,4 @@
package main
package log
import (
"errors"

49
main.go
View file

@ -22,6 +22,10 @@ import (
xproxy "golang.org/x/net/proxy"
"github.com/Snawoot/opera-proxy/clock"
"github.com/Snawoot/opera-proxy/dialer"
"github.com/Snawoot/opera-proxy/handler"
clog "github.com/Snawoot/opera-proxy/log"
se "github.com/Snawoot/opera-proxy/seclient"
)
@ -151,12 +155,12 @@ func parse_args() *CLIArgs {
}
func proxyFromURLWrapper(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
cdialer, ok := next.(ContextDialer)
cdialer, ok := next.(dialer.ContextDialer)
if !ok {
return nil, errors.New("only context dialers are accepted")
}
return ProxyDialerFromURL(u, cdialer)
return dialer.ProxyDialerFromURL(u, cdialer)
}
func run() int {
@ -166,19 +170,19 @@ func run() int {
return 0
}
logWriter := NewLogWriter(os.Stderr)
logWriter := clog.NewLogWriter(os.Stderr)
defer logWriter.Close()
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ",
mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ",
log.LstdFlags|log.Lshortfile),
args.verbosity)
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
proxyLogger := clog.NewCondLogger(log.New(logWriter, "PROXY : ",
log.LstdFlags|log.Lshortfile),
args.verbosity)
mainLogger.Info("opera-proxy client version %s is starting...", version)
var dialer ContextDialer = &net.Dialer{
var d dialer.ContextDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
@ -191,22 +195,22 @@ func run() int {
mainLogger.Critical("Unable to parse base proxy URL: %v", err)
return 6
}
pxDialer, err := xproxy.FromURL(proxyURL, dialer)
pxDialer, err := xproxy.FromURL(proxyURL, d)
if err != nil {
mainLogger.Critical("Unable to instantiate base proxy dialer: %v", err)
return 7
}
dialer = pxDialer.(ContextDialer)
d = pxDialer.(dialer.ContextDialer)
}
seclientDialer := dialer
seclientDialer := d
if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 {
var apiAddress string
if args.apiAddress != "" {
apiAddress = args.apiAddress
mainLogger.Info("Using fixed API host IP address = %s", apiAddress)
} else {
resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout)
resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout)
if err != nil {
mainLogger.Critical("Unable to instantiate DNS resolver: %v", err)
return 4
@ -234,7 +238,7 @@ func run() int {
apiAddress = addrs[0].String()
mainLogger.Info("Discovered address of API host = %s", apiAddress)
}
seclientDialer = NewFixedDialer(apiAddress, dialer)
seclientDialer = dialer.NewFixedDialer(apiAddress, d)
}
// Dialing w/o SNI, receiving self-signed certificate, so skip verification.
@ -303,7 +307,7 @@ func run() int {
return 13
}
runTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
defer cl()
@ -327,9 +331,6 @@ func run() int {
})
endpoint := ips[0]
auth := func() string {
return basic_auth_header(seclient.GetProxyCredentials())
}
var caPool *x509.CertPool
if args.caFile != "" {
@ -345,18 +346,26 @@ func run() int {
}
}
handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, args.certChainWorkaround, caPool, dialer)
handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
mainLogger.Info("Starting proxy server...")
handler := NewProxyHandler(handlerDialer, proxyLogger)
h := handler.NewProxyHandler(handlerDialer, proxyLogger)
mainLogger.Info("Init complete.")
err = http.ListenAndServe(args.bindAddress, handler)
err = http.ListenAndServe(args.bindAddress, h)
mainLogger.Critical("Server terminated with a reason: %v", err)
mainLogger.Info("Shutting down...")
return 0
}
func printCountries(logger *CondLogger, timeout time.Duration, seclient *se.SEClient) int {
func printCountries(logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int {
ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl()
list, err := seclient.GeoList(ctx)
@ -380,7 +389,7 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
login, password := seclient.GetProxyCredentials()
fmt.Println("Proxy login:", login)
fmt.Println("Proxy password:", password)
fmt.Println("Proxy-Authorization:", basic_auth_header(login, password))
fmt.Println("Proxy-Authorization:", dialer.BasicAuthHeader(login, password))
fmt.Println("")
wr.Write([]string{"host", "ip_address", "port"})
for i, ip := range ips {

190
utils.go
View file

@ -1,190 +0,0 @@
package main
import (
"bufio"
"context"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"sync"
"time"
)
const (
COPY_BUF = 128 * 1024
WALLCLOCK_PRECISION = 1 * time.Second
)
func basic_auth_header(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
}
func proxy(ctx context.Context, left, right net.Conn) {
wg := sync.WaitGroup{}
cpy := func(dst, src net.Conn) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
wg.Add(2)
go cpy(left, right)
go cpy(right, left)
groupdone := make(chan struct{})
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
left.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
wg := sync.WaitGroup{}
ltr := func(dst net.Conn, src io.Reader) {
defer wg.Done()
io.Copy(dst, src)
dst.Close()
}
rtl := func(dst io.Writer, src io.Reader) {
defer wg.Done()
copyBody(dst, src)
}
wg.Add(2)
go ltr(right, leftreader)
go rtl(leftwriter, right)
groupdone := make(chan struct{}, 1)
go func() {
wg.Wait()
groupdone <- struct{}{}
}()
select {
case <-ctx.Done():
leftreader.Close()
right.Close()
case <-groupdone:
return
}
<-groupdone
return
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Connection",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func delHopHeaders(header http.Header) {
for _, h := range hopHeaders {
header.Del(h)
}
}
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := hijackable.(http.Hijacker)
if !ok {
return nil, nil, errors.New("Connection doesn't support hijacking")
}
conn, rw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
var emptytime time.Time
err = conn.SetDeadline(emptytime)
if err != nil {
conn.Close()
return nil, nil, err
}
return conn, rw, nil
}
func flush(flusher interface{}) bool {
f, ok := flusher.(http.Flusher)
if !ok {
return false
}
f.Flush()
return true
}
func copyBody(wr io.Writer, body io.Reader) {
buf := make([]byte, COPY_BUF)
for {
bread, read_err := body.Read(buf)
var write_err error
if bread > 0 {
_, write_err = wr.Write(buf[:bread])
flush(wr)
}
if read_err != nil || write_err != nil {
break
}
}
}
func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <- t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <- t
return
}
}
}
}()
return ch
}
func runTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
go func() {
var err error
for {
nextInterval := interval
if err != nil {
nextInterval = retryInterval
}
select {
case <-ctx.Done():
return
case <-AfterWallClock(nextInterval):
err = cb(ctx)
}
}
}()
}