mirror of
https://github.com/Snawoot/opera-proxy.git
synced 2025-09-01 18:20:23 +00:00
split main package into subpackages
This commit is contained in:
parent
38b2b95dcb
commit
a05aad77aa
9 changed files with 263 additions and 242 deletions
50
clock/clock.go
Normal file
50
clock/clock.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package clock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const WALLCLOCK_PRECISION = 1 * time.Second
|
||||
|
||||
func AfterWallClock(d time.Duration) <-chan time.Time {
|
||||
ch := make(chan time.Time, 1)
|
||||
deadline := time.Now().Add(d).Truncate(0)
|
||||
after_ch := time.After(d)
|
||||
ticker := time.NewTicker(WALLCLOCK_PRECISION)
|
||||
go func() {
|
||||
var t time.Time
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case t = <-after_ch:
|
||||
ch <- t
|
||||
return
|
||||
case t = <-ticker.C:
|
||||
if t.After(deadline) {
|
||||
ch <- t
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func RunTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
|
||||
go func() {
|
||||
var err error
|
||||
for {
|
||||
nextInterval := interval
|
||||
if err != nil {
|
||||
nextInterval = retryInterval
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-AfterWallClock(nextInterval):
|
||||
err = cb(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -47,11 +48,11 @@ CV4Ks2dH/hzg1cEo70qLRDEmBDeNiXQ2Lu+lIg+DdEmSx/cQwgwp+7e9un/jX9Wf
|
|||
`
|
||||
)
|
||||
|
||||
var UpstreamBlockedError = errors.New("blocked by upstream")
|
||||
|
||||
var missingLinkDER, _ = pem.Decode([]byte(MISSING_CHAIN_CERT))
|
||||
var missingLink, _ = x509.ParseCertificate(missingLinkDER.Bytes)
|
||||
|
||||
type stringCb = func() (string, error)
|
||||
|
||||
type Dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
}
|
||||
|
@ -62,15 +63,15 @@ type ContextDialer interface {
|
|||
}
|
||||
|
||||
type ProxyDialer struct {
|
||||
address string
|
||||
tlsServerName string
|
||||
auth AuthProvider
|
||||
address stringCb
|
||||
tlsServerName stringCb
|
||||
auth stringCb
|
||||
next ContextDialer
|
||||
intermediateWorkaround bool
|
||||
caPool *x509.CertPool
|
||||
}
|
||||
|
||||
func NewProxyDialer(address, tlsServerName string, auth AuthProvider, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
|
||||
func NewProxyDialer(address, tlsServerName, auth stringCb, intermediateWorkaround bool, caPool *x509.CertPool, nextDialer ContextDialer) *ProxyDialer {
|
||||
return &ProxyDialer{
|
||||
address: address,
|
||||
tlsServerName: tlsServerName,
|
||||
|
@ -85,7 +86,7 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
|
|||
host := u.Hostname()
|
||||
port := u.Port()
|
||||
tlsServerName := ""
|
||||
var auth AuthProvider = nil
|
||||
var auth stringCb = nil
|
||||
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "http":
|
||||
|
@ -106,12 +107,9 @@ func ProxyDialerFromURL(u *url.URL, next ContextDialer) (*ProxyDialer, error) {
|
|||
if u.User != nil {
|
||||
username := u.User.Username()
|
||||
password, _ := u.User.Password()
|
||||
authHeader := basic_auth_header(username, password)
|
||||
auth = func() string {
|
||||
return authHeader
|
||||
}
|
||||
auth = WrapStringToCb(BasicAuthHeader(username, password))
|
||||
}
|
||||
return NewProxyDialer(address, tlsServerName, auth, false, nil, next), nil
|
||||
return NewProxyDialer(WrapStringToCb(address), WrapStringToCb(tlsServerName), auth, false, nil, next), nil
|
||||
}
|
||||
|
||||
func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
|
@ -121,12 +119,20 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
|
|||
return nil, errors.New("bad network specified for DialContext: only tcp is supported")
|
||||
}
|
||||
|
||||
conn, err := d.next.DialContext(ctx, "tcp", d.address)
|
||||
uAddress, err := d.address()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := d.next.DialContext(ctx, "tcp", uAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if d.tlsServerName != "" {
|
||||
uTLSServerName, err := d.tlsServerName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if uTLSServerName != "" {
|
||||
// Custom cert verification logic:
|
||||
// DO NOT send SNI extension of TLS ClientHello
|
||||
// DO peer certificate verification against specified servername
|
||||
|
@ -135,7 +141,7 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
|
|||
InsecureSkipVerify: true,
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: d.tlsServerName,
|
||||
DNSName: uTLSServerName,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
Roots: d.caPool,
|
||||
}
|
||||
|
@ -169,7 +175,11 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
|
|||
}
|
||||
|
||||
if d.auth != nil {
|
||||
req.Header.Set(PROXY_AUTHORIZATION_HEADER, d.auth())
|
||||
auth, err := d.auth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set(PROXY_AUTHORIZATION_HEADER, auth)
|
||||
}
|
||||
|
||||
rawreq, err := httputil.DumpRequest(req, false)
|
||||
|
@ -188,10 +198,6 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string)
|
|||
}
|
||||
|
||||
if proxyResp.StatusCode != http.StatusOK {
|
||||
if proxyResp.StatusCode == http.StatusForbidden &&
|
||||
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
||||
return nil, UpstreamBlockedError
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintf("bad response from upstream proxy server: %s", proxyResp.Status))
|
||||
}
|
||||
|
||||
|
@ -228,3 +234,14 @@ func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
return http.ReadResponse(bufio.NewReader(buf), req)
|
||||
}
|
||||
|
||||
func BasicAuthHeader(login, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString(
|
||||
[]byte(login+":"+password))
|
||||
}
|
||||
|
||||
func WrapStringToCb(s string) func() (string, error) {
|
||||
return func() (string, error) {
|
||||
return s, nil
|
||||
}
|
||||
}
|
|
@ -1,23 +1,33 @@
|
|||
package main
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Snawoot/opera-proxy/dialer"
|
||||
clog "github.com/Snawoot/opera-proxy/log"
|
||||
)
|
||||
|
||||
const BAD_REQ_MSG = "Bad Request\n"
|
||||
|
||||
type AuthProvider func() string
|
||||
const (
|
||||
COPY_BUF = 128 * 1024
|
||||
BAD_REQ_MSG = "Bad Request\n"
|
||||
)
|
||||
|
||||
type ProxyHandler struct {
|
||||
logger *CondLogger
|
||||
dialer ContextDialer
|
||||
logger *clog.CondLogger
|
||||
dialer dialer.ContextDialer
|
||||
httptransport http.RoundTripper
|
||||
}
|
||||
|
||||
func NewProxyHandler(dialer ContextDialer, logger *CondLogger) *ProxyHandler {
|
||||
func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler {
|
||||
httptransport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
|
@ -104,3 +114,128 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
|||
s.HandleRequest(wr, req)
|
||||
}
|
||||
}
|
||||
|
||||
func proxy(ctx context.Context, left, right net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
cpy := func(dst, src net.Conn) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
wg.Add(2)
|
||||
go cpy(left, right)
|
||||
go cpy(right, left)
|
||||
groupdone := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
groupdone <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
left.Close()
|
||||
right.Close()
|
||||
case <-groupdone:
|
||||
return
|
||||
}
|
||||
<-groupdone
|
||||
return
|
||||
}
|
||||
|
||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
ltr := func(dst net.Conn, src io.Reader) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
rtl := func(dst io.Writer, src io.Reader) {
|
||||
defer wg.Done()
|
||||
copyBody(dst, src)
|
||||
}
|
||||
wg.Add(2)
|
||||
go ltr(right, leftreader)
|
||||
go rtl(leftwriter, right)
|
||||
groupdone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
groupdone <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
leftreader.Close()
|
||||
right.Close()
|
||||
case <-groupdone:
|
||||
return
|
||||
}
|
||||
<-groupdone
|
||||
return
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Connection",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func delHopHeaders(header http.Header) {
|
||||
for _, h := range hopHeaders {
|
||||
header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := hijackable.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("Connection doesn't support hijacking")
|
||||
}
|
||||
conn, rw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var emptytime time.Time
|
||||
err = conn.SetDeadline(emptytime)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, rw, nil
|
||||
}
|
||||
|
||||
func flush(flusher interface{}) bool {
|
||||
f, ok := flusher.(http.Flusher)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
f.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
func copyBody(wr io.Writer, body io.Reader) {
|
||||
buf := make([]byte, COPY_BUF)
|
||||
for {
|
||||
bread, read_err := body.Read(buf)
|
||||
var write_err error
|
||||
if bread > 0 {
|
||||
_, write_err = wr.Write(buf[:bread])
|
||||
flush(wr)
|
||||
}
|
||||
if read_err != nil || write_err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package log
|
||||
|
||||
import (
|
||||
"errors"
|
49
main.go
49
main.go
|
@ -22,6 +22,10 @@ import (
|
|||
|
||||
xproxy "golang.org/x/net/proxy"
|
||||
|
||||
"github.com/Snawoot/opera-proxy/clock"
|
||||
"github.com/Snawoot/opera-proxy/dialer"
|
||||
"github.com/Snawoot/opera-proxy/handler"
|
||||
clog "github.com/Snawoot/opera-proxy/log"
|
||||
se "github.com/Snawoot/opera-proxy/seclient"
|
||||
)
|
||||
|
||||
|
@ -151,12 +155,12 @@ func parse_args() *CLIArgs {
|
|||
}
|
||||
|
||||
func proxyFromURLWrapper(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
|
||||
cdialer, ok := next.(ContextDialer)
|
||||
cdialer, ok := next.(dialer.ContextDialer)
|
||||
if !ok {
|
||||
return nil, errors.New("only context dialers are accepted")
|
||||
}
|
||||
|
||||
return ProxyDialerFromURL(u, cdialer)
|
||||
return dialer.ProxyDialerFromURL(u, cdialer)
|
||||
}
|
||||
|
||||
func run() int {
|
||||
|
@ -166,19 +170,19 @@ func run() int {
|
|||
return 0
|
||||
}
|
||||
|
||||
logWriter := NewLogWriter(os.Stderr)
|
||||
logWriter := clog.NewLogWriter(os.Stderr)
|
||||
defer logWriter.Close()
|
||||
|
||||
mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ",
|
||||
mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ",
|
||||
log.LstdFlags|log.Lshortfile),
|
||||
args.verbosity)
|
||||
proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ",
|
||||
proxyLogger := clog.NewCondLogger(log.New(logWriter, "PROXY : ",
|
||||
log.LstdFlags|log.Lshortfile),
|
||||
args.verbosity)
|
||||
|
||||
mainLogger.Info("opera-proxy client version %s is starting...", version)
|
||||
|
||||
var dialer ContextDialer = &net.Dialer{
|
||||
var d dialer.ContextDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
@ -191,22 +195,22 @@ func run() int {
|
|||
mainLogger.Critical("Unable to parse base proxy URL: %v", err)
|
||||
return 6
|
||||
}
|
||||
pxDialer, err := xproxy.FromURL(proxyURL, dialer)
|
||||
pxDialer, err := xproxy.FromURL(proxyURL, d)
|
||||
if err != nil {
|
||||
mainLogger.Critical("Unable to instantiate base proxy dialer: %v", err)
|
||||
return 7
|
||||
}
|
||||
dialer = pxDialer.(ContextDialer)
|
||||
d = pxDialer.(dialer.ContextDialer)
|
||||
}
|
||||
|
||||
seclientDialer := dialer
|
||||
seclientDialer := d
|
||||
if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 {
|
||||
var apiAddress string
|
||||
if args.apiAddress != "" {
|
||||
apiAddress = args.apiAddress
|
||||
mainLogger.Info("Using fixed API host IP address = %s", apiAddress)
|
||||
} else {
|
||||
resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout)
|
||||
resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout)
|
||||
if err != nil {
|
||||
mainLogger.Critical("Unable to instantiate DNS resolver: %v", err)
|
||||
return 4
|
||||
|
@ -234,7 +238,7 @@ func run() int {
|
|||
apiAddress = addrs[0].String()
|
||||
mainLogger.Info("Discovered address of API host = %s", apiAddress)
|
||||
}
|
||||
seclientDialer = NewFixedDialer(apiAddress, dialer)
|
||||
seclientDialer = dialer.NewFixedDialer(apiAddress, d)
|
||||
}
|
||||
|
||||
// Dialing w/o SNI, receiving self-signed certificate, so skip verification.
|
||||
|
@ -303,7 +307,7 @@ func run() int {
|
|||
return 13
|
||||
}
|
||||
|
||||
runTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
|
||||
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
|
||||
mainLogger.Info("Refreshing login...")
|
||||
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
|
||||
defer cl()
|
||||
|
@ -327,9 +331,6 @@ func run() int {
|
|||
})
|
||||
|
||||
endpoint := ips[0]
|
||||
auth := func() string {
|
||||
return basic_auth_header(seclient.GetProxyCredentials())
|
||||
}
|
||||
|
||||
var caPool *x509.CertPool
|
||||
if args.caFile != "" {
|
||||
|
@ -345,18 +346,26 @@ func run() int {
|
|||
}
|
||||
}
|
||||
|
||||
handlerDialer := NewProxyDialer(endpoint.NetAddr(), fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX), auth, args.certChainWorkaround, caPool, dialer)
|
||||
handlerDialer := dialer.NewProxyDialer(
|
||||
dialer.WrapStringToCb(endpoint.NetAddr()),
|
||||
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
|
||||
func() (string, error) {
|
||||
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
|
||||
},
|
||||
args.certChainWorkaround,
|
||||
caPool,
|
||||
d)
|
||||
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
|
||||
mainLogger.Info("Starting proxy server...")
|
||||
handler := NewProxyHandler(handlerDialer, proxyLogger)
|
||||
h := handler.NewProxyHandler(handlerDialer, proxyLogger)
|
||||
mainLogger.Info("Init complete.")
|
||||
err = http.ListenAndServe(args.bindAddress, handler)
|
||||
err = http.ListenAndServe(args.bindAddress, h)
|
||||
mainLogger.Critical("Server terminated with a reason: %v", err)
|
||||
mainLogger.Info("Shutting down...")
|
||||
return 0
|
||||
}
|
||||
|
||||
func printCountries(logger *CondLogger, timeout time.Duration, seclient *se.SEClient) int {
|
||||
func printCountries(logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int {
|
||||
ctx, cl := context.WithTimeout(context.Background(), timeout)
|
||||
defer cl()
|
||||
list, err := seclient.GeoList(ctx)
|
||||
|
@ -380,7 +389,7 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
|
|||
login, password := seclient.GetProxyCredentials()
|
||||
fmt.Println("Proxy login:", login)
|
||||
fmt.Println("Proxy password:", password)
|
||||
fmt.Println("Proxy-Authorization:", basic_auth_header(login, password))
|
||||
fmt.Println("Proxy-Authorization:", dialer.BasicAuthHeader(login, password))
|
||||
fmt.Println("")
|
||||
wr.Write([]string{"host", "ip_address", "port"})
|
||||
for i, ip := range ips {
|
||||
|
|
190
utils.go
190
utils.go
|
@ -1,190 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
COPY_BUF = 128 * 1024
|
||||
WALLCLOCK_PRECISION = 1 * time.Second
|
||||
)
|
||||
|
||||
func basic_auth_header(login, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString(
|
||||
[]byte(login+":"+password))
|
||||
}
|
||||
|
||||
func proxy(ctx context.Context, left, right net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
cpy := func(dst, src net.Conn) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
wg.Add(2)
|
||||
go cpy(left, right)
|
||||
go cpy(right, left)
|
||||
groupdone := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
groupdone <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
left.Close()
|
||||
right.Close()
|
||||
case <-groupdone:
|
||||
return
|
||||
}
|
||||
<-groupdone
|
||||
return
|
||||
}
|
||||
|
||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
ltr := func(dst net.Conn, src io.Reader) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
rtl := func(dst io.Writer, src io.Reader) {
|
||||
defer wg.Done()
|
||||
copyBody(dst, src)
|
||||
}
|
||||
wg.Add(2)
|
||||
go ltr(right, leftreader)
|
||||
go rtl(leftwriter, right)
|
||||
groupdone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
groupdone <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
leftreader.Close()
|
||||
right.Close()
|
||||
case <-groupdone:
|
||||
return
|
||||
}
|
||||
<-groupdone
|
||||
return
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Connection",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func delHopHeaders(header http.Header) {
|
||||
for _, h := range hopHeaders {
|
||||
header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := hijackable.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("Connection doesn't support hijacking")
|
||||
}
|
||||
conn, rw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var emptytime time.Time
|
||||
err = conn.SetDeadline(emptytime)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, rw, nil
|
||||
}
|
||||
|
||||
func flush(flusher interface{}) bool {
|
||||
f, ok := flusher.(http.Flusher)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
f.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
func copyBody(wr io.Writer, body io.Reader) {
|
||||
buf := make([]byte, COPY_BUF)
|
||||
for {
|
||||
bread, read_err := body.Read(buf)
|
||||
var write_err error
|
||||
if bread > 0 {
|
||||
_, write_err = wr.Write(buf[:bread])
|
||||
flush(wr)
|
||||
}
|
||||
if read_err != nil || write_err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterWallClock(d time.Duration) <-chan time.Time {
|
||||
ch := make(chan time.Time, 1)
|
||||
deadline := time.Now().Add(d).Truncate(0)
|
||||
after_ch := time.After(d)
|
||||
ticker := time.NewTicker(WALLCLOCK_PRECISION)
|
||||
go func() {
|
||||
var t time.Time
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case t = <-after_ch:
|
||||
ch <- t
|
||||
return
|
||||
case t = <-ticker.C:
|
||||
if t.After(deadline) {
|
||||
ch <- t
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func runTicker(ctx context.Context, interval, retryInterval time.Duration, cb func(context.Context) error) {
|
||||
go func() {
|
||||
var err error
|
||||
for {
|
||||
nextInterval := interval
|
||||
if err != nil {
|
||||
nextInterval = retryInterval
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-AfterWallClock(nextInterval):
|
||||
err = cb(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
Loading…
Add table
Reference in a new issue