mirror of
https://github.com/necronicle/z2k.git
synced 2026-04-28 03:20:25 +00:00
fix: rebuild all binaries with Go 1.22, remove legacy MTProxy code, add transparent WS reconnect
- Build all 9 arch binaries with Go 1.22.12 to fix MIPS crash (Go issue #71591) - Remove dead MTProxy/transparent mode code (relay.go, secret.go, transparent.go, dcmap.go) - Drop gotd/td dependency — only gorilla/websocket + stdlib remain - Tunnel is now the only mode, --tunnel flag removed - Transparent WS reconnect: keep client TCP alive during CF Worker WS drops - Re-CONNECT surviving streams after WS reconnects — seamless for clients - streamReadLoop waits for WS instead of dying on disconnect - New connections wait up to 5s for WS during reconnect instead of dropping - Drain connect semaphore on WS disconnect to prevent deadlock - Worker: MAX_STREAMS 200→100, improved logging Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0b263e1c52
commit
e688fb1963
19 changed files with 162 additions and 1575 deletions
|
|
@ -131,6 +131,7 @@ export default {
|
|||
|
||||
// Track TCP streams: streamId → { socket, writer }
|
||||
const streams = new Map();
|
||||
const MAX_STREAMS = 100; // prevent memory exhaustion
|
||||
let authenticated = false;
|
||||
|
||||
// Pre-compute expected auth HMAC
|
||||
|
|
@ -264,7 +265,16 @@ export default {
|
|||
// Dispatch by message type
|
||||
switch (frame.msgType) {
|
||||
case MUX_CONNECT:
|
||||
handleConnect(frame.streamId, frame.payload);
|
||||
if (streams.size >= MAX_STREAMS) {
|
||||
console.warn(`stream ${frame.streamId} rejected: ${streams.size}/${MAX_STREAMS} streams active`);
|
||||
sendFrame(frame.streamId, MUX_CONNECT_FAIL, null);
|
||||
} else {
|
||||
// No await — handleConnect runs its own read loop async
|
||||
handleConnect(frame.streamId, frame.payload).catch(e => {
|
||||
console.error(`stream ${frame.streamId} unhandled error: ${e.message}`);
|
||||
closeStream(frame.streamId);
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case MUX_DATA: {
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,101 +0,0 @@
|
|||
package main
|
||||
|
||||
import "net"
|
||||
|
||||
// dcEntry maps a CIDR to a Telegram DC number.
|
||||
type dcEntry struct {
|
||||
cidr *net.IPNet
|
||||
dc int16
|
||||
}
|
||||
|
||||
var dcTable []dcEntry
|
||||
|
||||
func init() {
|
||||
// Telegram DC IP ranges -> DC number.
|
||||
// Negative DC = media DC (abs value is the DC).
|
||||
entries := []struct {
|
||||
cidr string
|
||||
dc int16
|
||||
}{
|
||||
// DC1
|
||||
{"149.154.175.0/24", 1},
|
||||
// DC2
|
||||
{"149.154.167.0/24", 2},
|
||||
{"95.161.76.0/24", 2},
|
||||
// DC3
|
||||
{"149.154.175.100/32", 3},
|
||||
// DC4
|
||||
{"149.154.167.91/32", 4},
|
||||
// DC5
|
||||
{"149.154.171.0/24", 5},
|
||||
{"91.108.56.0/22", 5},
|
||||
// General Telegram ranges (default to DC2)
|
||||
{"91.108.4.0/22", 2},
|
||||
{"91.108.8.0/22", 2},
|
||||
{"91.108.12.0/22", 2},
|
||||
{"91.108.16.0/22", 2},
|
||||
{"91.108.20.0/22", 2},
|
||||
{"149.154.160.0/20", 2},
|
||||
{"185.76.151.0/24", 2},
|
||||
{"91.105.192.0/23", 2},
|
||||
{"95.161.64.0/20", 2},
|
||||
|
||||
// IPv6 Telegram ranges
|
||||
{"2001:b28:f23d::/48", 2}, // DC2 main IPv6
|
||||
{"2001:b28:f23f::/48", 5}, // DC5 IPv6
|
||||
{"2001:67c:4e8::/48", 2}, // General Telegram IPv6
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
_, cidr, err := net.ParseCIDR(e.cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
dcTable = append(dcTable, dcEntry{cidr: cidr, dc: e.dc})
|
||||
}
|
||||
}
|
||||
|
||||
// LookupDC returns the Telegram DC number for the given IP.
|
||||
// Returns 2 (default DC) if no match found. Supports both IPv4 and IPv6.
|
||||
func LookupDC(ip net.IP) int16 {
|
||||
isV6 := ip.To4() == nil
|
||||
|
||||
if !isV6 {
|
||||
// IPv4: check most specific first (single IPs for DC3/DC4).
|
||||
for _, e := range dcTable {
|
||||
ones, _ := e.cidr.Mask.Size()
|
||||
if ones == 32 && e.cidr.IP.Equal(ip) {
|
||||
return e.dc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check all CIDR ranges (both v4 and v6).
|
||||
// For IPv4, skip /32 (already checked above). For IPv6, check all.
|
||||
bestOnes := -1
|
||||
bestDC := int16(0)
|
||||
for _, e := range dcTable {
|
||||
ones, bits := e.cidr.Mask.Size()
|
||||
if !isV6 && ones == 32 {
|
||||
continue // already handled above for v4
|
||||
}
|
||||
if e.cidr.Contains(ip) {
|
||||
// Pick the most specific (longest prefix) match.
|
||||
// Normalize: compare relative specificity (ones out of bits).
|
||||
// For same address family, just compare ones directly.
|
||||
specificity := ones
|
||||
if bits == 128 {
|
||||
// IPv6 range
|
||||
specificity = ones
|
||||
}
|
||||
if specificity > bestOnes {
|
||||
bestOnes = specificity
|
||||
bestDC = e.dc
|
||||
}
|
||||
}
|
||||
}
|
||||
if bestOnes >= 0 {
|
||||
return bestDC
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
|
@ -1,20 +1,5 @@
|
|||
module github.com/necronicle/z2k/mtproxy-client
|
||||
|
||||
go 1.26.1
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/gotd/td v0.143.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-faster/errors v0.7.1 // indirect
|
||||
github.com/go-faster/jx v1.2.0 // indirect
|
||||
github.com/go-faster/xor v1.0.0 // indirect
|
||||
github.com/gotd/ige v0.2.2 // indirect
|
||||
github.com/gotd/neo v0.1.5 // indirect
|
||||
github.com/segmentio/asm v1.2.1 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.1 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
)
|
||||
require github.com/gorilla/websocket v1.5.3
|
||||
|
|
|
|||
|
|
@ -1,72 +1,2 @@
|
|||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
|
||||
github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo=
|
||||
github.com/go-faster/jx v1.2.0 h1:T2YHJPrFaYu21fJtUxC9GzmluKu8rVIFDwwGBKTDseI=
|
||||
github.com/go-faster/jx v1.2.0/go.mod h1:UWLOVDmMG597a5tBFPLIWJdUxz5/2emOpfsj9Neg0PE=
|
||||
github.com/go-faster/xor v0.3.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
|
||||
github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38=
|
||||
github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
|
||||
github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I=
|
||||
github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk=
|
||||
github.com/gotd/ige v0.2.2/go.mod h1:tuCRb+Y5Y3eNTo3ypIfNpQ4MFjrnONiL2jN2AKZXmb0=
|
||||
github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ=
|
||||
github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ=
|
||||
github.com/gotd/td v0.143.0 h1:p0U/Nn92zXmAsahDn5CIVzay2kQ36lBBENT/FlWR2nQ=
|
||||
github.com/gotd/td v0.143.0/go.mod h1:8GA5ecTI5iswLwBAlqf0u6/+j+BqSWUARSrX2Xk1usQ=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ogen-go/ogen v1.19.0 h1:YvdNpeQJ8A8dLLpS6Vs4WxXL53BT6tBPxH0VSjfALhA=
|
||||
github.com/ogen-go/ogen v1.19.0/go.mod h1:DeShwO+TEpLYXNCuZliSAedphphXsJaTGGbmSomWUjE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
|
||||
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
|||
|
|
@ -1,23 +1,9 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
|
@ -25,9 +11,6 @@ import (
|
|||
|
||||
var (
|
||||
listenAddr = flag.String("listen", ":1443", "Local listen address")
|
||||
secretHex = flag.String("secret", "", "Proxy secret (dd-prefixed hex, auto-generated if empty)")
|
||||
transparent = flag.Bool("transparent", false, "Transparent mode: redirect Telegram DC traffic via iptables (no client config needed)")
|
||||
tunnelMode = flag.Bool("tunnel", false, "Tunnel mode: multiplex TCP over single WS to Cloudflare Worker")
|
||||
tunnelURL = flag.String("tunnel-url", "wss://z2k-tunnel.necronicle.workers.dev/ws", "Cloudflare Worker WebSocket URL")
|
||||
tunnelSecret = flag.String("tunnel-secret", "d01f72f9543b29da4e3724b1530c0d11cb30a6f8db15bc0adfe8f2d37b5844b2", "Shared secret for tunnel auth")
|
||||
verbose = flag.Bool("v", false, "Verbose logging")
|
||||
|
|
@ -35,244 +18,9 @@ var (
|
|||
maxConns = flag.Int("max-conns", 1024, "Maximum concurrent connections")
|
||||
)
|
||||
|
||||
const handshakeLen = 64
|
||||
|
||||
// connSemaphore limits concurrent connections
|
||||
var connSemaphore chan struct{}
|
||||
|
||||
// DC WebSocket domains
|
||||
func wsDomains(dc int, isMedia bool) []string {
|
||||
if dc == 203 {
|
||||
dc = 2
|
||||
}
|
||||
if isMedia {
|
||||
return []string{
|
||||
fmt.Sprintf("kws%d-1.web.telegram.org", dc),
|
||||
fmt.Sprintf("kws%d.web.telegram.org", dc),
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
fmt.Sprintf("kws%d.web.telegram.org", dc),
|
||||
fmt.Sprintf("kws%d-1.web.telegram.org", dc),
|
||||
}
|
||||
}
|
||||
|
||||
// tryHandshake decrypts client's obfuscated2 header using proxy secret.
|
||||
// Returns DC id, isMedia flag, protocol tag, and AES key material.
|
||||
func tryHandshake(header []byte, secret []byte) (dc int, isMedia bool, protoTag uint32, decKey, decIV, encKey, encIV []byte, err error) {
|
||||
if len(header) != handshakeLen {
|
||||
return 0, false, 0, nil, nil, nil, nil, fmt.Errorf("header len %d", len(header))
|
||||
}
|
||||
|
||||
// Decrypt direction: client→proxy uses header[8:40] as key, header[40:56] as IV
|
||||
rawKey := make([]byte, 32)
|
||||
copy(rawKey, header[8:40])
|
||||
rawIV := make([]byte, 16)
|
||||
copy(rawIV, header[40:56])
|
||||
|
||||
// Mix with secret: key = SHA256(rawKey || secret)
|
||||
h := sha256.New()
|
||||
h.Write(rawKey)
|
||||
h.Write(secret)
|
||||
decKey = h.Sum(nil)
|
||||
decIV = rawIV
|
||||
|
||||
// Decrypt entire header to read protocol tag and DC
|
||||
block, cipherErr := aes.NewCipher(decKey)
|
||||
if cipherErr != nil {
|
||||
return 0, false, 0, nil, nil, nil, nil, fmt.Errorf("aes.NewCipher(decKey): %w", cipherErr)
|
||||
}
|
||||
stream := cipher.NewCTR(block, decIV)
|
||||
decrypted := make([]byte, handshakeLen)
|
||||
stream.XORKeyStream(decrypted, header)
|
||||
|
||||
protoTag = binary.LittleEndian.Uint32(decrypted[56:60])
|
||||
// Validate protocol tag
|
||||
if protoTag != 0xefefefef && protoTag != 0xeeeeeeee && protoTag != 0xdddddddd {
|
||||
return 0, false, 0, nil, nil, nil, nil, fmt.Errorf("bad proto tag 0x%08x", protoTag)
|
||||
}
|
||||
|
||||
dcIdx := int16(binary.LittleEndian.Uint16(decrypted[60:62]))
|
||||
dc = int(dcIdx)
|
||||
if dc < 0 {
|
||||
dc = -dc
|
||||
isMedia = true
|
||||
}
|
||||
if dc == 0 {
|
||||
dc = 2
|
||||
}
|
||||
|
||||
// Encrypt direction (proxy→client): reversed header[8:56]
|
||||
reversed := make([]byte, 48)
|
||||
copy(reversed, header[8:56])
|
||||
for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
|
||||
reversed[i], reversed[j] = reversed[j], reversed[i]
|
||||
}
|
||||
h2 := sha256.New()
|
||||
h2.Write(reversed[:32])
|
||||
h2.Write(secret)
|
||||
encKey = h2.Sum(nil)
|
||||
encIV = reversed[32:48]
|
||||
|
||||
return dc, isMedia, protoTag, decKey, decIV, encKey, encIV, nil
|
||||
}
|
||||
|
||||
// generateRelayInit creates a new obfuscated2 header for connecting to Telegram DC
|
||||
// (without proxy secret — direct DC connection).
|
||||
func generateRelayInit(protoTag uint32, dcIdx int) (header []byte, relayEncKey, relayEncIV, relayDecKey, relayDecIV []byte, err error) {
|
||||
header = make([]byte, handshakeLen)
|
||||
for {
|
||||
if _, err := io.ReadFull(rand.Reader, header); err != nil {
|
||||
return nil, nil, nil, nil, nil, err
|
||||
}
|
||||
if header[0] == 0xef {
|
||||
continue
|
||||
}
|
||||
first4 := binary.LittleEndian.Uint32(header[0:4])
|
||||
if first4 == 0x44414548 || first4 == 0x54534f50 || first4 == 0x20544547 ||
|
||||
first4 == 0x4954504f || first4 == 0x02010316 ||
|
||||
first4 == 0xdddddddd || first4 == 0xeeeeeeee {
|
||||
continue
|
||||
}
|
||||
if header[4]|header[5]|header[6]|header[7] == 0 {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Encryption key for relay→DC (our writes to DC)
|
||||
relayEncKey = make([]byte, 32)
|
||||
copy(relayEncKey, header[8:40])
|
||||
relayEncIV = make([]byte, 16)
|
||||
copy(relayEncIV, header[40:56])
|
||||
|
||||
// Decryption key for DC→relay (reads from DC): reversed
|
||||
reversed := make([]byte, 48)
|
||||
copy(reversed, header[8:56])
|
||||
for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
|
||||
reversed[i], reversed[j] = reversed[j], reversed[i]
|
||||
}
|
||||
relayDecKey = reversed[:32]
|
||||
relayDecIV = reversed[32:48]
|
||||
|
||||
// Write protocol tag and DC, encrypt with AES-CTR
|
||||
block, cipherErr := aes.NewCipher(relayEncKey)
|
||||
if cipherErr != nil {
|
||||
return nil, nil, nil, nil, nil, fmt.Errorf("aes.NewCipher(relayEncKey): %w", cipherErr)
|
||||
}
|
||||
encStream := cipher.NewCTR(block, relayEncIV)
|
||||
encrypted := make([]byte, handshakeLen)
|
||||
encStream.XORKeyStream(encrypted, header)
|
||||
|
||||
// Build tail: protocol_tag + dc_bytes + 2 random bytes
|
||||
tail := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint32(tail[0:4], protoTag)
|
||||
binary.LittleEndian.PutUint16(tail[4:6], uint16(int16(dcIdx)))
|
||||
if _, err := rand.Read(tail[6:8]); err != nil {
|
||||
return nil, nil, nil, nil, nil, fmt.Errorf("rand.Read: %w", err)
|
||||
}
|
||||
|
||||
// XOR tail with keystream at position 56
|
||||
for i := 0; i < 8; i++ {
|
||||
tail[i] ^= encrypted[56+i] ^ header[56+i]
|
||||
}
|
||||
copy(header[56:64], tail)
|
||||
|
||||
return header, relayEncKey, relayEncIV, relayDecKey, relayDecIV, nil
|
||||
}
|
||||
|
||||
// resolveIP resolves a hostname, preferring IPv4 but falling back to IPv6.
|
||||
func resolveIP(host string) (string, error) {
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Prefer IPv4
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
// Fall back to IPv6
|
||||
for _, ip := range ips {
|
||||
if ip.To4() == nil && ip.To16() != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no IP for %s", host)
|
||||
}
|
||||
|
||||
// resolveIPv4 resolves only IPv4 addresses (kept for backward compatibility).
|
||||
func resolveIPv4(host string) (string, error) {
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no IPv4 for %s", host)
|
||||
}
|
||||
|
||||
func connectWS(dc int, isMedia bool) (*websocket.Conn, error) {
|
||||
domains := wsDomains(dc, isMedia)
|
||||
|
||||
// Try direct WS first, then Cloudflare proxy fallback
|
||||
allDomains := make([]string, 0, len(domains)+1)
|
||||
allDomains = append(allDomains, domains...)
|
||||
allDomains = append(allDomains, fmt.Sprintf("kws%d.pclead.co.uk", dc))
|
||||
|
||||
for _, domain := range allDomains {
|
||||
ip, err := resolveIP(domain)
|
||||
if err != nil {
|
||||
if *verbose {
|
||||
log.Printf("[debug] resolve %s failed: %v", domain, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine dial network and address format based on IP version
|
||||
dialNetwork := "tcp4"
|
||||
dialAddr := ip + ":443"
|
||||
if net.ParseIP(ip).To4() == nil {
|
||||
dialNetwork = "tcp6"
|
||||
dialAddr = "[" + ip + "]:443"
|
||||
}
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: domain,
|
||||
},
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
Subprotocols: []string{"binary"},
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
return net.DialTimeout(dialNetwork, dialAddr, 5*time.Second)
|
||||
},
|
||||
}
|
||||
headers := http.Header{}
|
||||
headers.Set("Origin", "http://web.telegram.org")
|
||||
headers.Set("Host", domain)
|
||||
|
||||
url := fmt.Sprintf("wss://%s/apiws", domain)
|
||||
ws, _, err := dialer.Dial(url, headers)
|
||||
if err != nil {
|
||||
if *verbose {
|
||||
log.Printf("[debug] WS dial %s (%s) failed: %v", domain, ip, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Set read limit to prevent memory exhaustion
|
||||
ws.SetReadLimit(1 * 1024 * 1024) // 1MB max message
|
||||
if *verbose {
|
||||
log.Printf("[debug] WS connected to %s (%s)", domain, ip)
|
||||
}
|
||||
return ws, nil
|
||||
}
|
||||
return nil, fmt.Errorf("all WS domains failed for DC%d", dc)
|
||||
}
|
||||
|
||||
// wsWriter serializes all writes to a WebSocket connection.
|
||||
// gorilla/websocket supports only one concurrent writer.
|
||||
type wsWriter struct {
|
||||
|
|
@ -288,298 +36,15 @@ func (w *wsWriter) WriteMessage(messageType int, data []byte) error {
|
|||
}
|
||||
|
||||
func (w *wsWriter) WriteControl(messageType int, data []byte, deadline time.Time) error {
|
||||
// WriteControl is documented as thread-safe in gorilla/websocket,
|
||||
// but we serialize anyway for safety
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.ws.WriteControl(messageType, data, deadline)
|
||||
}
|
||||
|
||||
func handleConnection(ctx context.Context, clientConn *net.TCPConn, secret []byte) {
|
||||
defer clientConn.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[panic] %s: %v", clientConn.RemoteAddr(), r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set initial deadline for handshake
|
||||
clientConn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
|
||||
// Read client obfuscated2 header
|
||||
header := make([]byte, handshakeLen)
|
||||
if _, err := io.ReadFull(clientConn, header); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt with proxy secret
|
||||
dc, isMedia, protoTag, cltDecKey, cltDecIV, cltEncKey, cltEncIV, err := tryHandshake(header, secret)
|
||||
if err != nil {
|
||||
if *verbose {
|
||||
log.Printf("[error] handshake: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mediaTag := ""
|
||||
if isMedia {
|
||||
mediaTag = "m"
|
||||
}
|
||||
dcIdx := dc
|
||||
if isMedia {
|
||||
dcIdx = -dc
|
||||
}
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[conn] %s DC%d%s proto=0x%08x", clientConn.RemoteAddr(), dc, mediaTag, protoTag)
|
||||
}
|
||||
|
||||
// Generate relay header for Telegram DC (no secret)
|
||||
relayInit, relayEncKey, relayEncIV, relayDecKey, relayDecIV, err := generateRelayInit(protoTag, dcIdx)
|
||||
if err != nil {
|
||||
log.Printf("[error] relay init: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Connect via WebSocket to Telegram DC
|
||||
ws, err := connectWS(dc, isMedia)
|
||||
if err != nil {
|
||||
log.Printf("[error] WS connect DC%d%s: %v", dc, mediaTag, err)
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
writer := &wsWriter{ws: ws}
|
||||
|
||||
// Send relay init header as first WS message
|
||||
if err := writer.WriteMessage(websocket.BinaryMessage, relayInit); err != nil {
|
||||
log.Printf("[error] WS write init: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create AES-CTR streams
|
||||
cltDecBlock, err := aes.NewCipher(cltDecKey)
|
||||
if err != nil {
|
||||
log.Printf("[error] aes.NewCipher(cltDecKey): %v", err)
|
||||
return
|
||||
}
|
||||
cltDecStream := cipher.NewCTR(cltDecBlock, cltDecIV)
|
||||
// Advance past the 64-byte header
|
||||
skip := make([]byte, handshakeLen)
|
||||
cltDecStream.XORKeyStream(skip, skip)
|
||||
|
||||
cltEncBlock, err := aes.NewCipher(cltEncKey)
|
||||
if err != nil {
|
||||
log.Printf("[error] aes.NewCipher(cltEncKey): %v", err)
|
||||
return
|
||||
}
|
||||
cltEncStream := cipher.NewCTR(cltEncBlock, cltEncIV)
|
||||
|
||||
relayEncBlock, err := aes.NewCipher(relayEncKey)
|
||||
if err != nil {
|
||||
log.Printf("[error] aes.NewCipher(relayEncKey): %v", err)
|
||||
return
|
||||
}
|
||||
relayEncStream := cipher.NewCTR(relayEncBlock, relayEncIV)
|
||||
relayEncStream.XORKeyStream(make([]byte, handshakeLen), make([]byte, handshakeLen))
|
||||
|
||||
relayDecBlock, err := aes.NewCipher(relayDecKey)
|
||||
if err != nil {
|
||||
log.Printf("[error] aes.NewCipher(relayDecKey): %v", err)
|
||||
return
|
||||
}
|
||||
relayDecStream := cipher.NewCTR(relayDecBlock, relayDecIV)
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[relay] %s <-> WS DC%d%s", clientConn.RemoteAddr(), dc, mediaTag)
|
||||
}
|
||||
|
||||
// Reset deadline for data transfer
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
|
||||
// Create cancellable context for this connection
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
var upBytes, downBytes int64
|
||||
|
||||
// client → WS
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
n, err := clientConn.Read(buf)
|
||||
if n > 0 {
|
||||
// Reset deadline on activity
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
plain := make([]byte, n)
|
||||
cltDecStream.XORKeyStream(plain, buf[:n])
|
||||
encrypted := make([]byte, n)
|
||||
relayEncStream.XORKeyStream(encrypted, plain)
|
||||
if werr := writer.WriteMessage(websocket.BinaryMessage, encrypted); werr != nil {
|
||||
break
|
||||
}
|
||||
upBytes += int64(n)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// WS → client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
_, msg, rerr := ws.ReadMessage()
|
||||
if rerr != nil {
|
||||
if *verbose && websocket.IsUnexpectedCloseError(rerr, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
log.Printf("[debug] WS read error: %v", rerr)
|
||||
}
|
||||
break
|
||||
}
|
||||
if len(msg) > 0 {
|
||||
// Reset deadline on activity
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
plain := make([]byte, len(msg))
|
||||
relayDecStream.XORKeyStream(plain, msg)
|
||||
encrypted := make([]byte, len(msg))
|
||||
cltEncStream.XORKeyStream(encrypted, plain)
|
||||
if _, werr := clientConn.Write(encrypted); werr != nil {
|
||||
break
|
||||
}
|
||||
downBytes += int64(len(msg))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[done] %s DC%d%s up=%d down=%d", clientConn.RemoteAddr(), dc, mediaTag, upBytes, downBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *tunnelMode {
|
||||
// Tunnel mode: multiplex TCP over single WS to Cloudflare Worker
|
||||
if err := runTunnel(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if *transparent {
|
||||
// Transparent mode: iptables REDIRECT, no client config needed
|
||||
if err := transparentListener(*listenAddr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MTProxy mode: requires client configuration
|
||||
var secret []byte
|
||||
if *secretHex == "" {
|
||||
secret = make([]byte, 16)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
log.Fatalf("Failed to generate secret: %v", err)
|
||||
}
|
||||
*secretHex = fmt.Sprintf("dd%x", secret)
|
||||
log.Printf("Generated secret: %s", *secretHex)
|
||||
} else {
|
||||
parsed, err := parseSecretHex(*secretHex)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid secret: %v", err)
|
||||
}
|
||||
secret = parsed
|
||||
}
|
||||
|
||||
// Initialize connection limiter
|
||||
connSemaphore = make(chan struct{}, *maxConns)
|
||||
|
||||
ln, err := net.Listen("tcp", *listenAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("Listen %s: %v", *listenAddr, err)
|
||||
}
|
||||
|
||||
// Parse host and port safely
|
||||
host := "ROUTER_IP"
|
||||
_, port, splitErr := net.SplitHostPort(*listenAddr)
|
||||
if splitErr != nil {
|
||||
port = *listenAddr
|
||||
}
|
||||
log.Printf("tg-ws-proxy listening on %s", *listenAddr)
|
||||
log.Printf("Add proxy in Telegram: tg://proxy?server=%s&port=%s&secret=%s", host, port, *secretHex)
|
||||
|
||||
// Graceful shutdown
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Println("[shutdown] Closing listener...")
|
||||
ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("[shutdown] Server stopped")
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limit connections
|
||||
select {
|
||||
case connSemaphore <- struct{}{}:
|
||||
go func() {
|
||||
defer func() { <-connSemaphore }()
|
||||
handleConnection(ctx, tcpConn, secret)
|
||||
}()
|
||||
default:
|
||||
if *verbose {
|
||||
log.Printf("[warn] max connections reached, rejecting %s", conn.RemoteAddr())
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
if err := runTunnel(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func parseSecretHex(s string) ([]byte, error) {
|
||||
if len(s) < 34 || (s[:2] != "dd" && s[:2] != "ee") {
|
||||
return nil, fmt.Errorf("secret must start with dd or ee and be at least 34 hex chars")
|
||||
}
|
||||
raw := make([]byte, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
_, err := fmt.Sscanf(s[2+i*2:4+i*2], "%02x", &raw[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,353 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTryHandshake_ValidHeader(t *testing.T) {
|
||||
secret := make([]byte, 16)
|
||||
rand.Read(secret)
|
||||
|
||||
// Generate a valid relay init (which also creates a valid header)
|
||||
header, _, _, _, _, err := generateRelayInit(0xefefefef, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("generateRelayInit failed: %v", err)
|
||||
}
|
||||
|
||||
// tryHandshake expects a header encrypted WITH secret, so this won't
|
||||
// match the proto tag. That's fine — we test that it returns an error
|
||||
// for bad proto tag (the header was not encrypted with the secret).
|
||||
_, _, _, _, _, _, _, herr := tryHandshake(header, secret)
|
||||
if herr == nil {
|
||||
t.Log("Handshake succeeded (unexpected but not necessarily wrong)")
|
||||
}
|
||||
// The main thing is it doesn't panic
|
||||
}
|
||||
|
||||
func TestTryHandshake_WrongLength(t *testing.T) {
|
||||
secret := make([]byte, 16)
|
||||
rand.Read(secret)
|
||||
|
||||
_, _, _, _, _, _, _, err := tryHandshake([]byte("short"), secret)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRelayInit_NoCollision(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
header, encKey, encIV, decKey, decIV, err := generateRelayInit(0xeeeeeeee, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("generateRelayInit failed: %v", err)
|
||||
}
|
||||
if len(header) != handshakeLen {
|
||||
t.Fatalf("header len = %d, want %d", len(header), handshakeLen)
|
||||
}
|
||||
if len(encKey) != 32 || len(encIV) != 16 || len(decKey) != 32 || len(decIV) != 16 {
|
||||
t.Fatalf("key/iv lengths wrong")
|
||||
}
|
||||
// Verify header[0] != 0xef (excluded)
|
||||
if header[0] == 0xef {
|
||||
t.Fatal("header[0] should never be 0xef")
|
||||
}
|
||||
// Verify first4 is not a forbidden value
|
||||
first4 := binary.LittleEndian.Uint32(header[0:4])
|
||||
forbidden := []uint32{0x44414548, 0x54534f50, 0x20544547, 0x4954504f, 0x02010316, 0xdddddddd, 0xeeeeeeee}
|
||||
for _, f := range forbidden {
|
||||
if first4 == f {
|
||||
t.Fatalf("header first4 = 0x%08x (forbidden)", first4)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSecretHex_Valid(t *testing.T) {
|
||||
// dd + 32 hex chars = valid
|
||||
hex := "dd0123456789abcdef0123456789abcdef"
|
||||
secret, err := parseSecretHex(hex)
|
||||
if err != nil {
|
||||
t.Fatalf("parseSecretHex failed: %v", err)
|
||||
}
|
||||
if len(secret) != 16 {
|
||||
t.Fatalf("secret len = %d, want 16", len(secret))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSecretHex_TooShort(t *testing.T) {
|
||||
_, err := parseSecretHex("dd01234567")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSecretHex_BadPrefix(t *testing.T) {
|
||||
_, err := parseSecretHex("aa0123456789abcdef0123456789abcdef")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupDC_KnownRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected int16
|
||||
}{
|
||||
{"149.154.175.1", 1}, // DC1
|
||||
{"149.154.167.50", 2}, // DC2
|
||||
{"149.154.175.100", 3}, // DC3 (specific IP)
|
||||
{"149.154.167.91", 4}, // DC4 (specific IP)
|
||||
{"149.154.171.10", 5}, // DC5
|
||||
{"91.108.56.1", 5}, // DC5
|
||||
{"8.8.8.8", 2}, // Unknown → default DC2
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
got := LookupDC(ip)
|
||||
if got != tt.expected {
|
||||
t.Errorf("LookupDC(%s) = %d, want %d", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupDC_Specificity(t *testing.T) {
|
||||
// DC3 is 149.154.175.100/32, DC1 is 149.154.175.0/24
|
||||
// DC3 should win for exact IP match
|
||||
ip := net.ParseIP("149.154.175.100")
|
||||
got := LookupDC(ip)
|
||||
if got != 3 {
|
||||
t.Errorf("LookupDC(149.154.175.100) = %d, want 3", got)
|
||||
}
|
||||
|
||||
// But a different IP in the /24 should be DC1
|
||||
ip2 := net.ParseIP("149.154.175.99")
|
||||
got2 := LookupDC(ip2)
|
||||
if got2 != 1 {
|
||||
t.Errorf("LookupDC(149.154.175.99) = %d, want 1", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWsDomains(t *testing.T) {
|
||||
// Non-media: primary domain first
|
||||
domains := wsDomains(2, false)
|
||||
if len(domains) != 2 {
|
||||
t.Fatalf("expected 2 domains, got %d", len(domains))
|
||||
}
|
||||
if domains[0] != "kws2.web.telegram.org" {
|
||||
t.Errorf("domains[0] = %s, want kws2.web.telegram.org", domains[0])
|
||||
}
|
||||
|
||||
// Media: -1 domain first
|
||||
mediaDomains := wsDomains(2, true)
|
||||
if mediaDomains[0] != "kws2-1.web.telegram.org" {
|
||||
t.Errorf("mediaDomains[0] = %s, want kws2-1.web.telegram.org", mediaDomains[0])
|
||||
}
|
||||
|
||||
// DC 203 maps to DC 2
|
||||
dc203 := wsDomains(203, false)
|
||||
if dc203[0] != "kws2.web.telegram.org" {
|
||||
t.Errorf("DC203 domains[0] = %s, want kws2.web.telegram.org", dc203[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupDC_IPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected int16
|
||||
}{
|
||||
{"2001:b28:f23d::1", 2}, // DC2 main IPv6 range
|
||||
{"2001:b28:f23d:f:1::2", 2}, // DC2 within /48
|
||||
{"2001:b28:f23f::1", 5}, // DC5 IPv6 range
|
||||
{"2001:b28:f23f:a:b::c", 5}, // DC5 within /48
|
||||
{"2001:67c:4e8::1", 2}, // General Telegram IPv6
|
||||
{"2001:67c:4e8:ff::1", 2}, // General Telegram IPv6 within /48
|
||||
{"2600:1234::1", 2}, // Unknown IPv6 → default DC2
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %s", tt.ip)
|
||||
}
|
||||
got := LookupDC(ip)
|
||||
if got != tt.expected {
|
||||
t.Errorf("LookupDC(%s) = %d, want %d", tt.ip, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOriginalDst_SkipWithoutIptables(t *testing.T) {
|
||||
t.Skip("getOriginalDst requires iptables REDIRECT; skipping in unit tests")
|
||||
}
|
||||
|
||||
func TestWsWriter_Serialization(t *testing.T) {
|
||||
// Just verify the struct compiles and methods exist
|
||||
var _ *wsWriter
|
||||
// Full test would require a mock WebSocket, skipping
|
||||
}
|
||||
|
||||
func TestMuxEncodeDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
streamID uint16
|
||||
msgType byte
|
||||
payload []byte
|
||||
}{
|
||||
{"CONNECT empty", 1, muxCONNECT, []byte{addrIPv4, 149, 154, 175, 1, 0x01, 0xBB}},
|
||||
{"DATA small", 42, muxDATA, []byte("hello world")},
|
||||
{"DATA large", 65535, muxDATA, make([]byte, 64*1024)},
|
||||
{"CLOSE no payload", 100, muxCLOSE, nil},
|
||||
{"CONNECT_OK", 7, muxCONNECT_OK, nil},
|
||||
{"CONNECT_FAIL", 8, muxCONNECT_FAIL, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
encoded := encodeMuxFrame(tt.streamID, tt.msgType, tt.payload)
|
||||
|
||||
// Verify minimum length
|
||||
if len(encoded) < 3 {
|
||||
t.Fatalf("encoded frame too short: %d", len(encoded))
|
||||
}
|
||||
|
||||
decoded, err := decodeMuxFrame(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeMuxFrame failed: %v", err)
|
||||
}
|
||||
|
||||
if decoded.StreamID != tt.streamID {
|
||||
t.Errorf("StreamID = %d, want %d", decoded.StreamID, tt.streamID)
|
||||
}
|
||||
if decoded.MsgType != tt.msgType {
|
||||
t.Errorf("MsgType = 0x%02x, want 0x%02x", decoded.MsgType, tt.msgType)
|
||||
}
|
||||
|
||||
// Compare payloads
|
||||
if tt.payload == nil {
|
||||
if len(decoded.Payload) != 0 {
|
||||
t.Errorf("Payload len = %d, want 0", len(decoded.Payload))
|
||||
}
|
||||
} else {
|
||||
if len(decoded.Payload) != len(tt.payload) {
|
||||
t.Fatalf("Payload len = %d, want %d", len(decoded.Payload), len(tt.payload))
|
||||
}
|
||||
for i := range tt.payload {
|
||||
if decoded.Payload[i] != tt.payload[i] {
|
||||
t.Errorf("Payload[%d] = 0x%02x, want 0x%02x", i, decoded.Payload[i], tt.payload[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMuxDecodeFrameTooShort(t *testing.T) {
|
||||
_, err := decodeMuxFrame([]byte{0x00})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short frame")
|
||||
}
|
||||
_, err = decodeMuxFrame([]byte{0x00, 0x01})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 2-byte frame")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectPayloadIPv4(t *testing.T) {
|
||||
ip := net.ParseIP("149.154.175.1")
|
||||
port := 443
|
||||
payload := encodeConnectPayload(ip, port)
|
||||
|
||||
gotIP, gotPort, err := decodeConnectPayload(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeConnectPayload: %v", err)
|
||||
}
|
||||
if !gotIP.Equal(ip.To4()) {
|
||||
t.Errorf("IP = %s, want %s", gotIP, ip)
|
||||
}
|
||||
if gotPort != port {
|
||||
t.Errorf("Port = %d, want %d", gotPort, port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectPayloadIPv6(t *testing.T) {
|
||||
ip := net.ParseIP("2001:b28:f23d::1")
|
||||
port := 443
|
||||
payload := encodeConnectPayload(ip, port)
|
||||
|
||||
gotIP, gotPort, err := decodeConnectPayload(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeConnectPayload: %v", err)
|
||||
}
|
||||
if !gotIP.Equal(ip) {
|
||||
t.Errorf("IP = %s, want %s", gotIP, ip)
|
||||
}
|
||||
if gotPort != port {
|
||||
t.Errorf("Port = %d, want %d", gotPort, port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectPayloadRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
port int
|
||||
}{
|
||||
{"149.154.167.91", 443},
|
||||
{"91.108.56.1", 8443},
|
||||
{"10.0.0.1", 1},
|
||||
{"255.255.255.255", 65535},
|
||||
{"2001:b28:f23f::1", 443},
|
||||
{"::1", 80},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
payload := encodeConnectPayload(ip, tt.port)
|
||||
gotIP, gotPort, err := decodeConnectPayload(payload)
|
||||
if err != nil {
|
||||
t.Errorf("decodeConnectPayload(%s:%d): %v", tt.ip, tt.port, err)
|
||||
continue
|
||||
}
|
||||
// Normalize for comparison
|
||||
if ip.To4() != nil {
|
||||
ip = ip.To4()
|
||||
}
|
||||
if !gotIP.Equal(ip) {
|
||||
t.Errorf("IP = %s, want %s", gotIP, ip)
|
||||
}
|
||||
if gotPort != tt.port {
|
||||
t.Errorf("Port = %d, want %d", gotPort, tt.port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeAuthHMAC(t *testing.T) {
|
||||
mac1 := computeAuthHMAC("test-secret")
|
||||
mac2 := computeAuthHMAC("test-secret")
|
||||
mac3 := computeAuthHMAC("different-secret")
|
||||
|
||||
if len(mac1) != 32 {
|
||||
t.Fatalf("HMAC length = %d, want 32", len(mac1))
|
||||
}
|
||||
|
||||
// Same secret should produce same HMAC
|
||||
for i := range mac1 {
|
||||
if mac1[i] != mac2[i] {
|
||||
t.Fatal("same secret produced different HMACs")
|
||||
}
|
||||
}
|
||||
|
||||
// Different secret should produce different HMAC
|
||||
same := true
|
||||
for i := range mac1 {
|
||||
if mac1[i] != mac3[i] {
|
||||
same = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if same {
|
||||
t.Fatal("different secrets produced same HMAC")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/gotd/td/mtproxy/obfuscator"
|
||||
)
|
||||
|
||||
// relay bidirectionally copies data between a TCP client and an obfuscated MTProxy connection.
|
||||
func relay(client net.Conn, server *obfuscator.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// client → server (raw MTProto → encrypted)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := io.Copy(server, client)
|
||||
if *verbose && (err != nil || n == 0) {
|
||||
log.Printf("[relay-detail] client→server: %d bytes, err=%v", n, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// server → client (encrypted → raw MTProto)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := io.Copy(client, server)
|
||||
if *verbose && (err != nil || n == 0) {
|
||||
log.Printf("[relay-detail] server→client: %d bytes, err=%v", n, err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
client.Close()
|
||||
server.Close()
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/mtproxy"
|
||||
)
|
||||
|
||||
// ParseSecret decodes an ee-prefixed hex secret string.
|
||||
// Format (tdesktop-compatible): ee + tag(1) + secret(16) + sni(rest)
|
||||
// Tag byte may not match standard codec tags — we force PaddedIntermediate.
|
||||
func ParseSecret(hexStr string) (mtproxy.Secret, error) {
|
||||
if len(hexStr) < 4 || hexStr[:2] != "ee" {
|
||||
return mtproxy.Secret{}, fmt.Errorf("secret must start with 'ee' (FakeTLS mode)")
|
||||
}
|
||||
|
||||
raw, err := hex.DecodeString(hexStr[2:])
|
||||
if err != nil {
|
||||
return mtproxy.Secret{}, fmt.Errorf("invalid hex: %w", err)
|
||||
}
|
||||
|
||||
if len(raw) < 17 {
|
||||
return mtproxy.Secret{}, fmt.Errorf("secret too short: need 1+16+sni bytes, got %d", len(raw))
|
||||
}
|
||||
|
||||
// mtg format (confirmed working with mtproto.ru servers):
|
||||
// raw[0:16] = secret key (includes tag byte as part of key)
|
||||
// raw[16:] = SNI domain
|
||||
// Force PaddedIntermediate (0xdd) as protocol tag.
|
||||
return mtproxy.Secret{
|
||||
Secret: raw[0:16],
|
||||
Tag: 0xdd,
|
||||
CloakHost: string(raw[16:]),
|
||||
Type: mtproxy.TLS,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,346 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// DNS cache to survive temporary resolver failures.
|
||||
var (
|
||||
dnsCache sync.Map // domain → *dnsCacheEntry
|
||||
dnsCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
type dnsCacheEntry struct {
|
||||
ip string
|
||||
ts time.Time
|
||||
}
|
||||
|
||||
// resolveIPCached resolves a hostname with caching, preferring IPv4 but supporting IPv6.
|
||||
func resolveIPCached(host string) (string, error) {
|
||||
// Check cache
|
||||
if val, ok := dnsCache.Load(host); ok {
|
||||
entry := val.(*dnsCacheEntry)
|
||||
if time.Since(entry.ts) < dnsCacheTTL {
|
||||
return entry.ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try resolving (prefers IPv4, falls back to IPv6)
|
||||
newIP, err := resolveIP(host)
|
||||
if err != nil {
|
||||
// DNS failed — use stale cache if available (max 1 hour)
|
||||
if val, ok := dnsCache.Load(host); ok {
|
||||
entry := val.(*dnsCacheEntry)
|
||||
if time.Since(entry.ts) < 1*time.Hour {
|
||||
if *verbose {
|
||||
log.Printf("[debug] DNS failed for %s, using cached %s (age %s)", host, entry.ip, time.Since(entry.ts))
|
||||
}
|
||||
return entry.ip, nil
|
||||
}
|
||||
if *verbose {
|
||||
log.Printf("[debug] DNS failed for %s, stale cache expired (age %s)", host, time.Since(entry.ts))
|
||||
}
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update cache atomically
|
||||
dnsCache.Store(host, &dnsCacheEntry{ip: newIP, ts: time.Now()})
|
||||
return newIP, nil
|
||||
}
|
||||
|
||||
// handleTransparent redirects intercepted Telegram traffic through
|
||||
// Cloudflare WebSocket. Optimized for throughput:
|
||||
// - TCP_NODELAY on client connection (disable Nagle)
|
||||
// - Buffered reads with flush coalescing (reduce WS frame count)
|
||||
// - Large WebSocket write/read buffers
|
||||
func handleTransparent(ctx context.Context, clientConn *net.TCPConn) {
|
||||
defer clientConn.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[panic] %s: %v", clientConn.RemoteAddr(), r)
|
||||
}
|
||||
}()
|
||||
|
||||
origIP, _, err := getOriginalDst(clientConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dc := LookupDC(origIP)
|
||||
isMedia := false
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[conn] %s -> DC%d (%s)", clientConn.RemoteAddr(), dc, origIP)
|
||||
}
|
||||
|
||||
// Performance: disable Nagle's algorithm — send data immediately
|
||||
clientConn.SetNoDelay(true)
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
|
||||
// Connect via WebSocket with retry
|
||||
var ws *websocket.Conn
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
ws, err = connectWSTransparent(int(dc), isMedia)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if attempt < 2 {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if *verbose {
|
||||
log.Printf("[error] WS DC%d: %v", dc, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Performance: enable WebSocket compression if server supports it
|
||||
ws.EnableWriteCompression(true)
|
||||
|
||||
writer := &wsWriter{ws: ws}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Keepalive: CF kills idle WS after 100s. Ping every 50s.
|
||||
go func() {
|
||||
ticker := time.NewTicker(50 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
writer.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[relay] %s <-> WS DC%d", clientConn.RemoteAddr(), dc)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// client → WS: buffered reader coalesces small TCP segments into larger WS frames
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
reader := bufio.NewReaderSize(clientConn, 128*1024) // 128KB read buffer
|
||||
buf := make([]byte, 128*1024)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read as much as available (buffered — coalesces small segments)
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
if werr := writer.WriteMessage(websocket.BinaryMessage, buf[:n]); werr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// WS → client: direct write with write buffer
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
clientWriter := bufio.NewWriterSize(clientConn, 128*1024) // 128KB write buffer
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
_, msg, rerr := ws.ReadMessage()
|
||||
if rerr != nil {
|
||||
if *verbose && websocket.IsUnexpectedCloseError(rerr, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
log.Printf("[debug] WS read error: %v", rerr)
|
||||
}
|
||||
break
|
||||
}
|
||||
if len(msg) > 0 {
|
||||
clientConn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
if _, werr := clientWriter.Write(msg); werr != nil {
|
||||
break
|
||||
}
|
||||
// Flush immediately if buffer has enough data or ws has no more pending
|
||||
clientWriter.Flush()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[done] %s DC%d", clientConn.RemoteAddr(), dc)
|
||||
}
|
||||
}
|
||||
|
||||
func connectWSTransparent(dc int, isMedia bool) (*websocket.Conn, error) {
|
||||
cfDomain := fmt.Sprintf("kws%d.pclead.co.uk", dc)
|
||||
|
||||
ip, err := resolveIPCached(cfDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve %s: %w", cfDomain, err)
|
||||
}
|
||||
|
||||
// Determine dial network and address format based on IP version
|
||||
dialNetwork := "tcp4"
|
||||
dialAddr := ip + ":443"
|
||||
if net.ParseIP(ip) != nil && net.ParseIP(ip).To4() == nil {
|
||||
dialNetwork = "tcp6"
|
||||
dialAddr = "[" + ip + "]:443"
|
||||
}
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: cfDomain,
|
||||
},
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
Subprotocols: []string{"binary"},
|
||||
ReadBufferSize: 128 * 1024, // 128KB WS read buffer
|
||||
WriteBufferSize: 128 * 1024, // 128KB WS write buffer
|
||||
EnableCompression: true, // per-message deflate
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout(dialNetwork, dialAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TCP_NODELAY on WS connection too
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
tcpConn.SetNoDelay(true)
|
||||
}
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
headers := http.Header{}
|
||||
headers.Set("Origin", "http://web.telegram.org")
|
||||
headers.Set("Host", cfDomain)
|
||||
|
||||
url := fmt.Sprintf("wss://%s/apiws", cfDomain)
|
||||
ws, _, err := dialer.Dial(url, headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s (%s): %w", cfDomain, ip, err)
|
||||
}
|
||||
|
||||
// Set read limit to prevent memory exhaustion
|
||||
ws.SetReadLimit(2 * 1024 * 1024) // 2MB for media
|
||||
|
||||
if *verbose {
|
||||
log.Printf("[debug] WS connected to %s (%s)", cfDomain, ip)
|
||||
}
|
||||
return ws, nil
|
||||
}
|
||||
|
||||
// transparentListener runs the transparent proxy mode with graceful shutdown.
|
||||
func transparentListener(listenAddr string) error {
|
||||
ln, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize connection limiter
|
||||
connSemaphore = make(chan struct{}, *maxConns)
|
||||
|
||||
// Pre-warm DNS cache for all DCs
|
||||
for _, dc := range []int{1, 2, 3, 4, 5} {
|
||||
domain := fmt.Sprintf("kws%d.pclead.co.uk", dc)
|
||||
if ip, err := resolveIP(domain); err == nil {
|
||||
dnsCache.Store(domain, &dnsCacheEntry{ip: ip, ts: time.Now()})
|
||||
log.Printf("DNS cache: %s -> %s", domain, ip)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("tg-transparent-proxy listening on %s", listenAddr)
|
||||
|
||||
// Graceful shutdown
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// Periodic DNS cache refresh
|
||||
go func() {
|
||||
ticker := time.NewTicker(dnsCacheTTL)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
for _, dc := range []int{1, 2, 3, 4, 5} {
|
||||
domain := fmt.Sprintf("kws%d.pclead.co.uk", dc)
|
||||
if ip, err := resolveIP(domain); err == nil {
|
||||
dnsCache.Store(domain, &dnsCacheEntry{ip: ip, ts: time.Now()})
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Println("[shutdown] Closing listener...")
|
||||
ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("[shutdown] Transparent proxy stopped")
|
||||
return nil
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limit connections
|
||||
select {
|
||||
case connSemaphore <- struct{}{}:
|
||||
go func() {
|
||||
defer func() { <-connSemaphore }()
|
||||
handleTransparent(ctx, tcpConn)
|
||||
}()
|
||||
default:
|
||||
if *verbose {
|
||||
log.Printf("[warn] max connections reached, rejecting %s", conn.RemoteAddr())
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -23,10 +23,10 @@ import (
|
|||
|
||||
// Mux message types
|
||||
const (
|
||||
muxCONNECT = 0x01
|
||||
muxDATA = 0x02
|
||||
muxCLOSE = 0x03
|
||||
muxCONNECT_OK = 0x04
|
||||
muxCONNECT = 0x01
|
||||
muxDATA = 0x02
|
||||
muxCLOSE = 0x03
|
||||
muxCONNECT_OK = 0x04
|
||||
muxCONNECT_FAIL = 0x05
|
||||
)
|
||||
|
||||
|
|
@ -83,33 +83,6 @@ func encodeConnectPayload(ip net.IP, port int) []byte {
|
|||
return buf
|
||||
}
|
||||
|
||||
// decodeConnectPayload parses a CONNECT payload into IP and port.
|
||||
func decodeConnectPayload(data []byte) (net.IP, int, error) {
|
||||
if len(data) < 1 {
|
||||
return nil, 0, fmt.Errorf("empty CONNECT payload")
|
||||
}
|
||||
switch data[0] {
|
||||
case addrIPv4:
|
||||
if len(data) < 7 {
|
||||
return nil, 0, fmt.Errorf("IPv4 CONNECT payload too short: %d", len(data))
|
||||
}
|
||||
ip := net.IP(make([]byte, 4))
|
||||
copy(ip, data[1:5])
|
||||
port := int(binary.BigEndian.Uint16(data[5:7]))
|
||||
return ip, port, nil
|
||||
case addrIPv6:
|
||||
if len(data) < 19 {
|
||||
return nil, 0, fmt.Errorf("IPv6 CONNECT payload too short: %d", len(data))
|
||||
}
|
||||
ip := net.IP(make([]byte, 16))
|
||||
copy(ip, data[1:17])
|
||||
port := int(binary.BigEndian.Uint16(data[17:19]))
|
||||
return ip, port, nil
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("unknown addr type: %d", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
// computeAuthHMAC computes the HMAC-SHA256 of the shared secret (keyed by itself).
|
||||
func computeAuthHMAC(secret string) []byte {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
|
|
@ -122,27 +95,38 @@ type tunnelClient struct {
|
|||
tunnelURL string
|
||||
tunnelSecret string
|
||||
|
||||
ws *websocket.Conn
|
||||
writer *wsWriter
|
||||
streams sync.Map // uint16 → *tunnelStream
|
||||
nextID atomic.Uint32
|
||||
mu sync.Mutex // protects ws/writer replacement during reconnect
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ws *websocket.Conn
|
||||
writer *wsWriter
|
||||
streams sync.Map // uint16 → *tunnelStream
|
||||
nextID atomic.Uint32
|
||||
mu sync.Mutex // protects ws/writer replacement during reconnect
|
||||
connectSem chan struct{} // limits concurrent CONNECT to CF Workers limit
|
||||
wsReady chan struct{} // closed when WS is connected, recreated on disconnect
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type tunnelStream struct {
|
||||
id uint16
|
||||
conn *net.TCPConn
|
||||
client *tunnelClient
|
||||
id uint16
|
||||
conn *net.TCPConn
|
||||
client *tunnelClient
|
||||
origIP net.IP // original destination IP (for re-CONNECT after reconnect)
|
||||
origPort int // original destination port
|
||||
closeOnce sync.Once
|
||||
upBytes atomic.Int64
|
||||
downBytes atomic.Int64
|
||||
connected atomic.Bool // true after first CONNECT_OK
|
||||
}
|
||||
|
||||
func (s *tunnelStream) close() {
|
||||
s.closeOnce.Do(func() {
|
||||
up := s.upBytes.Load()
|
||||
down := s.downBytes.Load()
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] stream %d closed (up=%d down=%d)", s.id, up, down)
|
||||
}
|
||||
s.conn.Close()
|
||||
s.client.streams.Delete(s.id)
|
||||
// Send CLOSE frame (best effort)
|
||||
s.client.mu.Lock()
|
||||
w := s.client.writer
|
||||
s.client.mu.Unlock()
|
||||
|
|
@ -165,7 +149,6 @@ func (tc *tunnelClient) connectTunnelWS() (*websocket.Conn, error) {
|
|||
WriteBufferSize: 128 * 1024,
|
||||
EnableCompression: true,
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
// Force IPv4 — IPv6 to Cloudflare is unstable on some ISPs
|
||||
conn, err := net.DialTimeout("tcp4", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -197,14 +180,33 @@ func (tc *tunnelClient) connectTunnelWS() (*websocket.Conn, error) {
|
|||
return ws, nil
|
||||
}
|
||||
|
||||
// closeAllStreams closes all active tunnel streams.
|
||||
func (tc *tunnelClient) closeAllStreams() {
|
||||
// reConnectStreams re-sends CONNECT for all surviving streams after WS reconnect.
|
||||
func (tc *tunnelClient) reConnectStreams() {
|
||||
tc.mu.Lock()
|
||||
w := tc.writer
|
||||
tc.mu.Unlock()
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
|
||||
count := 0
|
||||
tc.streams.Range(func(key, value any) bool {
|
||||
stream := value.(*tunnelStream)
|
||||
stream.conn.Close()
|
||||
tc.streams.Delete(key)
|
||||
stream.connected.Store(false)
|
||||
|
||||
connectPayload := encodeConnectPayload(stream.origIP, stream.origPort)
|
||||
frame := encodeMuxFrame(stream.id, muxCONNECT, connectPayload)
|
||||
if err := w.WriteMessage(websocket.BinaryMessage, frame); err != nil {
|
||||
log.Printf("[tunnel] stream %d re-CONNECT write error: %v", stream.id, err)
|
||||
stream.close()
|
||||
return true
|
||||
}
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
log.Printf("[tunnel] re-CONNECTed %d surviving streams", count)
|
||||
}
|
||||
}
|
||||
|
||||
// readLoop reads mux frames from the WS and dispatches to streams.
|
||||
|
|
@ -218,6 +220,9 @@ func (tc *tunnelClient) readLoop(ws *websocket.Conn) {
|
|||
return
|
||||
}
|
||||
|
||||
// Any incoming message means WS is alive — extend read deadline
|
||||
ws.SetReadDeadline(time.Now().Add(120 * time.Second))
|
||||
|
||||
frame, err := decodeMuxFrame(msg)
|
||||
if err != nil {
|
||||
if *verbose {
|
||||
|
|
@ -237,6 +242,7 @@ func (tc *tunnelClient) readLoop(ws *websocket.Conn) {
|
|||
|
||||
switch frame.MsgType {
|
||||
case muxDATA:
|
||||
stream.downBytes.Add(int64(len(frame.Payload)))
|
||||
stream.conn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
if _, err := stream.conn.Write(frame.Payload); err != nil {
|
||||
if *verbose {
|
||||
|
|
@ -253,12 +259,20 @@ func (tc *tunnelClient) readLoop(ws *websocket.Conn) {
|
|||
tc.streams.Delete(frame.StreamID)
|
||||
|
||||
case muxCONNECT_OK:
|
||||
select {
|
||||
case <-tc.connectSem:
|
||||
default:
|
||||
}
|
||||
stream.connected.Store(true)
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] stream %d CONNECT_OK", frame.StreamID)
|
||||
}
|
||||
// streamReadLoop already started in handleTunnelConn
|
||||
|
||||
case muxCONNECT_FAIL:
|
||||
select {
|
||||
case <-tc.connectSem:
|
||||
default:
|
||||
}
|
||||
log.Printf("[tunnel] stream %d CONNECT_FAIL", frame.StreamID)
|
||||
stream.conn.Close()
|
||||
tc.streams.Delete(frame.StreamID)
|
||||
|
|
@ -272,6 +286,7 @@ func (tc *tunnelClient) readLoop(ws *websocket.Conn) {
|
|||
}
|
||||
|
||||
// streamReadLoop reads from a TCP client and sends DATA frames over WS.
|
||||
// Survives WS reconnects: waits for writer to become available again.
|
||||
func (tc *tunnelClient) streamReadLoop(stream *tunnelStream) {
|
||||
defer stream.close()
|
||||
|
||||
|
|
@ -281,19 +296,28 @@ func (tc *tunnelClient) streamReadLoop(stream *tunnelStream) {
|
|||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
stream.upBytes.Add(int64(n))
|
||||
stream.conn.SetDeadline(time.Now().Add(*connTimeout))
|
||||
frame := encodeMuxFrame(stream.id, muxDATA, buf[:n])
|
||||
tc.mu.Lock()
|
||||
w := tc.writer
|
||||
tc.mu.Unlock()
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
if werr := w.WriteMessage(websocket.BinaryMessage, frame); werr != nil {
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] stream %d WS write error: %v", stream.id, werr)
|
||||
|
||||
// Wait for WS to be available (survives reconnect)
|
||||
for attempt := 0; attempt < 50; attempt++ {
|
||||
tc.mu.Lock()
|
||||
w := tc.writer
|
||||
tc.mu.Unlock()
|
||||
if w != nil {
|
||||
if werr := w.WriteMessage(websocket.BinaryMessage, frame); werr != nil {
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] stream %d WS write error: %v", stream.id, werr)
|
||||
}
|
||||
// Write failed — WS probably just died, wait for reconnect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
break // success
|
||||
}
|
||||
return
|
||||
// No writer — WS is reconnecting, wait
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
|
@ -327,6 +351,17 @@ func (tc *tunnelClient) run() {
|
|||
tc.writer = &wsWriter{ws: ws}
|
||||
tc.mu.Unlock()
|
||||
|
||||
// PongHandler: update read deadline when pong received
|
||||
ws.SetPongHandler(func(appData string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(120 * time.Second))
|
||||
return nil
|
||||
})
|
||||
// Initial read deadline
|
||||
ws.SetReadDeadline(time.Now().Add(120 * time.Second))
|
||||
|
||||
// Re-CONNECT surviving streams from previous WS session
|
||||
tc.reConnectStreams()
|
||||
|
||||
// Keepalive: ping every 50s (CF kills idle WS after 100s)
|
||||
pingDone := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -351,14 +386,23 @@ func (tc *tunnelClient) run() {
|
|||
// Read loop blocks until WS disconnects
|
||||
tc.readLoop(ws)
|
||||
|
||||
// WS disconnected — clean up
|
||||
log.Printf("[tunnel] WS disconnected, closing all streams")
|
||||
// WS disconnected — DON'T close client TCP connections
|
||||
log.Printf("[tunnel] WS disconnected, keeping streams alive for reconnect")
|
||||
tc.mu.Lock()
|
||||
tc.ws = nil
|
||||
tc.writer = nil
|
||||
tc.mu.Unlock()
|
||||
ws.Close()
|
||||
tc.closeAllStreams()
|
||||
|
||||
// Drain connect semaphore — pending CONNECTs died with the WS
|
||||
for {
|
||||
select {
|
||||
case <-tc.connectSem:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
|
||||
// Wait for ping goroutine
|
||||
select {
|
||||
|
|
@ -390,25 +434,45 @@ func (tc *tunnelClient) handleTunnelConn(clientConn *net.TCPConn) {
|
|||
return
|
||||
}
|
||||
|
||||
// Allocate stream ID (wrap around at 65535)
|
||||
rawID := tc.nextID.Add(1)
|
||||
streamID := uint16(rawID % 65535) + 1 // 1..65535, avoid 0
|
||||
// Allocate stream ID — skip IDs still in use (prevents wrap-around collision)
|
||||
var streamID uint16
|
||||
for i := 0; i < 100; i++ {
|
||||
rawID := tc.nextID.Add(1)
|
||||
streamID = uint16(rawID%65535) + 1
|
||||
if _, exists := tc.streams.Load(streamID); !exists {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Wait up to 5s for WS to be ready (handles new connections during reconnect)
|
||||
tc.mu.Lock()
|
||||
w := tc.writer
|
||||
tc.mu.Unlock()
|
||||
if w == nil {
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] no WS connection, dropping stream %d", streamID)
|
||||
for i := 0; i < 50; i++ {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
tc.mu.Lock()
|
||||
w = tc.writer
|
||||
tc.mu.Unlock()
|
||||
if w != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if w == nil {
|
||||
if *verbose {
|
||||
log.Printf("[tunnel] no WS connection after waiting, dropping stream %d", streamID)
|
||||
}
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
stream := &tunnelStream{
|
||||
id: streamID,
|
||||
conn: clientConn,
|
||||
client: tc,
|
||||
id: streamID,
|
||||
conn: clientConn,
|
||||
client: tc,
|
||||
origIP: origIP,
|
||||
origPort: origPort,
|
||||
}
|
||||
tc.streams.Store(streamID, stream)
|
||||
|
||||
|
|
@ -416,19 +480,27 @@ func (tc *tunnelClient) handleTunnelConn(clientConn *net.TCPConn) {
|
|||
log.Printf("[tunnel] stream %d: %s -> %s:%d", streamID, clientConn.RemoteAddr(), origIP, origPort)
|
||||
}
|
||||
|
||||
// Rate-limit concurrent CONNECTs to stay within CF Workers 6-connection limit
|
||||
select {
|
||||
case tc.connectSem <- struct{}{}:
|
||||
case <-time.After(10 * time.Second):
|
||||
log.Printf("[tunnel] stream %d CONNECT throttled (timeout)", streamID)
|
||||
stream.conn.Close()
|
||||
tc.streams.Delete(streamID)
|
||||
return
|
||||
}
|
||||
|
||||
// Send CONNECT frame
|
||||
connectPayload := encodeConnectPayload(origIP, origPort)
|
||||
frame := encodeMuxFrame(streamID, muxCONNECT, connectPayload)
|
||||
if err := w.WriteMessage(websocket.BinaryMessage, frame); err != nil {
|
||||
<-tc.connectSem
|
||||
log.Printf("[tunnel] stream %d CONNECT write error: %v", streamID, err)
|
||||
stream.conn.Close()
|
||||
tc.streams.Delete(streamID)
|
||||
return
|
||||
}
|
||||
|
||||
// Start reading from client immediately — data will be buffered
|
||||
// in WS until Worker's TCP connect completes. Worker queues DATA
|
||||
// frames and writes them after socket.opened resolves.
|
||||
go tc.streamReadLoop(stream)
|
||||
}
|
||||
|
||||
|
|
@ -456,10 +528,10 @@ func runTunnel() error {
|
|||
tc := &tunnelClient{
|
||||
tunnelURL: *tunnelURL,
|
||||
tunnelSecret: *tunnelSecret,
|
||||
connectSem: make(chan struct{}, 6),
|
||||
}
|
||||
tc.ctx, tc.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Start persistent WS connection manager
|
||||
go tc.run()
|
||||
|
||||
go func() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue