Commit 9da9ed8
2025-04-03 05:55:14
Changed files (11)
go/connection.go
@@ -0,0 +1,91 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "net"
+
+ "github.com/google/uuid"
+ "golang.org/x/sync/errgroup"
+)
+
+type connection struct {
+ id uuid.UUID
+ dAddr AddrRequest
+ sConn, dConn *net.TCPConn
+ ctx context.Context
+ cancel context.CancelFunc
+ errgroup *errgroup.Group
+}
+
+// TODO: Options
+// - higher context
+func NewConnection(
+ sConn *net.TCPConn,
+) (*connection, error) {
+ ctx := context.Background()
+
+ groupCtx, cancel := context.WithCancel(ctx)
+ g, connCtx := errgroup.WithContext(groupCtx)
+
+ c := &connection{
+ id: uuid.New(),
+ sConn: sConn,
+ ctx: connCtx,
+ cancel: cancel,
+ errgroup: g,
+ }
+
+ err := handleClientIdentifier(sConn)
+ if err != nil {
+ return nil, fmt.Errorf("failed to handle socks5 client id: %w", err)
+ }
+
+ c.dAddr, err = ParseClientRequest(sConn)
+ if err != nil {
+ return nil, fmt.Errorf("failed to handle socks5 client request: %w", err)
+ }
+
+ err = c.DialRemote()
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial remote: %w", err)
+ }
+
+ _, err = sConn.Write(replyV4(ReplySucceeded))
+
+ return c, nil
+}
+
+func (c *connection) Id() string { return c.id.String() }
+
+func (c *connection) DialRemote() error {
+
+ var tcpAddr *net.TCPAddr
+ var err error
+
+ if c.dAddr.Domain != nil {
+ tcpAddr, err = net.ResolveTCPAddr("tcp", *c.dAddr.Domain)
+ if err != nil {
+ return err
+ }
+ }
+ if c.dAddr.AddrPort != nil {
+ tcpAddr = net.TCPAddrFromAddrPort(*c.dAddr.AddrPort)
+ }
+
+ conn, err := net.DialTCP("tcp", nil, tcpAddr)
+ if err != nil {
+ return err
+ }
+ err = conn.SetReadBuffer(bufSize)
+ if err != nil {
+ return err
+ }
+ err = conn.SetWriteBuffer(bufSize)
+ if err != nil {
+ return err
+ }
+
+ c.dConn = conn
+ return nil
+}
go/direction.go
@@ -0,0 +1,12 @@
+package main
+
+type direction string
+
+func (d direction) String() string {
+ return string(d)
+}
+
+const (
+ send direction = "send"
+ recv direction = "recv"
+)
go/go.mod
@@ -0,0 +1,11 @@
+module stitch
+
+go 1.23.0
+
+toolchain go1.24.1
+
+require (
+ github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/google/uuid v1.6.0 // indirect
+ golang.org/x/sync v0.12.0 // indirect
+)
go/go.sum
@@ -0,0 +1,6 @@
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
go/main.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+ "fmt"
+ "log/slog"
+ "net"
+ "time"
+)
+
+// TODO:
+// - [x] parsing socks5 bytes into functions
+// - [ ] report results and cobine in send+recv results
+// - connection id
+// - [] seralization of replies into functions ->
+// - [] better goroutine mgmt
+// - [] signal handlers
+// - [] split server and client with abstract reader/writers and messages
+// - (?) avoid structs, [id, len, data] for established connections
+// - (?) server-to-server Reply ATYP=05 id
+
+// TODO: options
+
+const (
+ timeout = time.Second * 30
+ bufSize = 128 * 1024
+)
+
+func NewServer() (*net.TCPListener, error) {
+ network := "tcp4"
+ address := "127.0.0.1"
+ port := 9000
+ slog.Info("starting socks server",
+ slog.String("network", network),
+ slog.String("address", address),
+ slog.Int("port", port),
+ )
+ addr := fmt.Sprintf("%s:%d", address, port)
+ lAddr, err := net.ResolveTCPAddr("tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse addr=%q: %w", addr, err)
+ }
+ l, err := net.ListenTCP(network, lAddr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to start server: %w", err)
+ }
+ return l, nil
+}
+
+func main() {
+
+ l, err := NewServer()
+ if err != nil {
+ fmt.Println(err)
+ return
+ }
+ defer l.Close()
+
+ for {
+ sconn, err := l.AcceptTCP()
+ if err != nil {
+ slog.Error("failed to accept conenction",
+ slog.String("error", err.Error()))
+ return
+ }
+
+ go func() {
+ conn, err := NewConnection(sconn)
+ if err != nil {
+ slog.Error("failed to setup proxy",
+ slog.String("error", err.Error()))
+ return
+ }
+
+ err = conn.Proxy()
+ if err != nil {
+ slog.Error("failed to proxy connection",
+ slog.String("error", err.Error()))
+ return
+ }
+ }()
+ }
+}
go/proxy.go
@@ -0,0 +1,92 @@
+package main
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net"
+ "time"
+
+ "github.com/dustin/go-humanize"
+)
+
+func (c connection) Proxy() error {
+ c.proxyData(send)
+ c.proxyData(recv)
+ return c.errgroup.Wait()
+}
+
+func (c connection) proxyData(d direction) {
+
+ c.errgroup.Go(func() error {
+
+ // send/recv direction
+ id := c.Id()
+ name := d.String()
+ src, dst := c.sConn, c.dConn
+ if d == recv {
+ src, dst = c.dConn, c.sConn
+ }
+
+ start := time.Now()
+ lastByte := time.Now()
+ bytes := uint64(0)
+
+ defer func() {
+ c.cancel()
+ speed := uint64(0)
+ elapsed := time.Since(start).Round(time.Second)
+ if elapsed > 0 {
+ speed = bytes / uint64(elapsed.Seconds())
+ }
+ elapsedLB := time.Since(lastByte).Round(time.Second)
+ if elapsedLB > 0 {
+ speed = bytes / uint64(elapsedLB.Seconds())
+ }
+ human_speed := humanize.Bytes(speed)
+ slog.Info("complete",
+ slog.String("conn", d.String()),
+ slog.String("speed", human_speed+"/s"),
+ slog.Uint64("bytes", bytes),
+ slog.Duration("duration", elapsed),
+ slog.Duration("durationLB", elapsedLB))
+ }()
+
+ buf := make([]byte, bufSize)
+ for {
+ select {
+ case <-c.ctx.Done():
+ err := src.Close()
+ if err != nil {
+ return fmt.Errorf("proxy: %s failed to close id=%q: %w",
+ name, id, err)
+ }
+ return nil
+ default:
+ src.SetReadDeadline(time.Now().Add(timeout))
+ n, err := src.Read(buf)
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+ opErr, ok := err.(*net.OpError)
+ if ok && opErr.Timeout() {
+ return fmt.Errorf("proxy: %s timeout id=%q: %w",
+ name, id, err)
+ }
+ return fmt.Errorf("proxy: %s failed to read id=%q: %w",
+ name, id, err)
+ }
+ lastByte = time.Now()
+ bytes += uint64(n)
+ dst.SetWriteDeadline(time.Now().Add(timeout))
+ _, err = dst.Write(buf[:n])
+ if err != nil {
+ return fmt.Errorf("proxy: %s failed to write id=%q: %w",
+ name, id, err)
+ }
+ }
+ }
+ })
+}
go/socks.go
@@ -0,0 +1,226 @@
+package main
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+)
+
+const (
+ VersionSocks5 = byte(0x05)
+
+ MethodNoAuthRequired = byte(0x00)
+ MethodNoAcceptableMethod = byte(0xFF)
+
+ CommandConnect = byte(0x01)
+
+ AddressTypeIPV4 = byte(0x01)
+ AddressTypeDomain = byte(0x03)
+ AddressTypeIPV6 = byte(0x04)
+
+ ReplySucceeded = byte(0x00)
+ ReplyGeneralFailure = byte(0x01)
+
+ Reserved = byte(0x00)
+)
+
+var (
+ // Useful byte arrays
+ _zeroIPv4 = []byte(net.IPv4zero.To4())
+ _zeroPort = []byte{0x00, 0x00}
+)
+
+// AddrRequest represents a union type that holds either an netip.AddrPort
+// or a string. Only one of the fields (AddrPort or Domain) should be non-nil
+// at a time. This enables a socks5 parsed client request to return any of the
+// three valid address type destaniation addr information in a struct that is
+// useable without interface{} or reflection.
+type AddrRequest struct {
+ AddrPort *netip.AddrPort
+ Domain *string
+}
+
+// reply creates a byte slice for a socks5 reply (RFC 1928 Section 6)
+//
+// +----+-----+-------+------+----------+----------+
+// |VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
+// +----+-----+-------+------+----------+----------+
+// | 1 | 1 | X'00' | 1 | Variable | 2 |
+// +----+-----+-------+------+----------+----------+
+func replyV4(rep byte) []byte {
+ reply := make([]byte, 4, 10)
+ reply[0] = VersionSocks5
+ reply[1] = rep
+ reply[2] = Reserved
+ reply[3] = AddressTypeIPV4
+ reply = append(reply, _zeroIPv4...)
+ reply = append(reply, _zeroPort...)
+ return reply
+}
+
+// handleClientIdentifier
+//
+// +----+----------+----------+
+// |VER | NMETHODS | METHODS |
+// +----+----------+----------+
+// | 1 | 1 | 1 to 255 |
+// +----+----------+----------+
+func handleClientIdentifier(conn io.ReadWriter) error {
+
+ // version and nmethods
+ client := make([]byte, 2)
+ n, err := conn.Read(client)
+ if err != nil {
+ return fmt.Errorf("failed to read client identifier: %w", err)
+ }
+ if n != 2 {
+ return fmt.Errorf("failed to parse client identifier: bad length")
+ }
+
+ // only socks5 supported
+ ver := client[0]
+ if ver != VersionSocks5 {
+ _, err := conn.Write([]byte{VersionSocks5, MethodNoAcceptableMethod})
+ if err != nil {
+ return fmt.Errorf("failed to write to socket: %w", err)
+ }
+ return fmt.Errorf("version not supported version=%q", ver)
+ }
+
+ // methods
+ nmethods := int(client[1])
+ methods := make([]byte, nmethods)
+ n, err = conn.Read(methods)
+ if err != nil {
+ return fmt.Errorf("failed to read client methods: %w", err)
+ }
+ if n != nmethods {
+ return fmt.Errorf("failed to parse client methods: bad length")
+ }
+
+ // only no auth supported
+ match := false
+ for _, m := range methods {
+ if m == MethodNoAuthRequired {
+ match = true
+ break
+ }
+ }
+ if !match {
+ _, err := conn.Write([]byte{VersionSocks5, MethodNoAcceptableMethod})
+ if err != nil {
+ return fmt.Errorf("failed to write to socket: %w", err)
+ }
+ return fmt.Errorf("no method compatibility found")
+ }
+
+ _, err = conn.Write([]byte{VersionSocks5, MethodNoAuthRequired})
+ return nil
+}
+
+func ParseClientRequest(conn io.ReadWriter) (AddrRequest, error) {
+ var (
+ ar AddrRequest
+ addr netip.Addr
+ domain string
+ )
+
+ req := make([]byte, 4) // version and count
+ n, err := conn.Read(req)
+ if err != nil {
+ return ar, fmt.Errorf("failed to read client request: %w", err)
+ }
+ if n != 4 {
+ return ar, fmt.Errorf("failed to parse client request: bad length")
+ }
+
+ // socks5 and connect only
+ ver, cmd := req[0], req[1]
+ if ver != VersionSocks5 || cmd != CommandConnect {
+ _, err := conn.Write([]byte{VersionSocks5, ReplyGeneralFailure})
+ if err != nil {
+ return ar, fmt.Errorf("failed to write to socket: %w", err)
+ }
+ return ar, fmt.Errorf("unspported request ver=%q cmd=%q",
+ ver, cmd)
+ }
+
+ addrType := req[3]
+ switch addrType {
+
+ case AddressTypeIPV4:
+
+ ipRaw := [net.IPv4len]byte{}
+ n, err := conn.Read(ipRaw[:])
+ if err != nil {
+ return ar, fmt.Errorf("failed to read dst address: %w", err)
+ }
+ if n != net.IPv4len {
+ return ar, fmt.Errorf("failed to parse dst address: bad length")
+ }
+ addr = netip.AddrFrom4(ipRaw)
+
+ case AddressTypeIPV6:
+
+ ipRaw := [net.IPv6len]byte{}
+ n, err := conn.Read(ipRaw[:])
+ if err != nil {
+ return ar, fmt.Errorf("failed to read dst address: %w", err)
+ }
+ if n != net.IPv6len {
+ return ar, fmt.Errorf("failed to parse dst address: bad length")
+ }
+
+ addr = netip.AddrFrom16(ipRaw)
+
+ case AddressTypeDomain:
+
+ domainLen := make([]byte, 1)
+ n, err := conn.Read(domainLen)
+ if err != nil {
+ return ar, fmt.Errorf("failed to read dst domain length: %w", err)
+ }
+ if n != 1 {
+ return ar, fmt.Errorf("failed to parse dst domain length: bad length")
+ }
+ domainRaw := make([]byte, int(domainLen[0]))
+ n, err = conn.Read(domainRaw)
+ if err != nil {
+ return ar, fmt.Errorf("failed to read dst domain: %w", err)
+ }
+ if n != int(domainLen[0]) {
+ return ar, fmt.Errorf("failed to parse dst domain: bad length")
+ }
+ domain = string(domainRaw)
+
+ default:
+
+ return ar, fmt.Errorf("unknown address type")
+ }
+
+ portRaw := make([]byte, 2)
+ n, err = conn.Read(portRaw)
+ if err != nil {
+ return ar, fmt.Errorf("failed to read dst port: %w", err)
+ }
+ if n != 2 {
+ return ar, fmt.Errorf("failed to read dst port: bad length")
+ }
+ // destination port in network octet order
+ port := binary.BigEndian.Uint16(portRaw)
+
+ switch addrType {
+ case AddressTypeIPV4, AddressTypeIPV6:
+ addrPort := netip.AddrPortFrom(addr, port)
+ ar.AddrPort = &addrPort
+ return ar, nil
+ case AddressTypeDomain:
+ domainPort := fmt.Sprintf("%s:%d", domain, port)
+ ar.Domain = &domainPort
+ return ar, nil
+ default:
+ return ar, fmt.Errorf("unknown address type")
+ }
+}
zig/src/main.zig
@@ -0,0 +1,24 @@
+const std = @import("std");
+
+pub fn main() !void {
+ // Prints to stderr (it's a shortcut based on `std.io.getStdErr()`)
+ std.debug.print("All your {s} are belong to us.\n", .{"codebase"});
+
+ // stdout is for the actual output of your application, for example if you
+ // are implementing gzip, then only the compressed bytes should be sent to
+ // stdout, not any debugging messages.
+ const stdout_file = std.io.getStdOut().writer();
+ var bw = std.io.bufferedWriter(stdout_file);
+ const stdout = bw.writer();
+
+ try stdout.print("Run `zig build test` to run the tests.\n", .{});
+
+ try bw.flush(); // don't forget to flush!
+}
+
+test "simple test" {
+ var list = std.ArrayList(i32).init(std.testing.allocator);
+ defer list.deinit(); // try commenting this out and see if zig detects the memory leak!
+ try list.append(42);
+ try std.testing.expectEqual(@as(i32, 42), list.pop());
+}
zig/src/root.zig
@@ -0,0 +1,10 @@
+const std = @import("std");
+const testing = std.testing;
+
+export fn add(a: i32, b: i32) i32 {
+ return a + b;
+}
+
+test "basic add functionality" {
+ try testing.expect(add(3, 7) == 10);
+}
zig/build.zig
@@ -0,0 +1,91 @@
+const std = @import("std");
+
+// Although this function looks imperative, note that its job is to
+// declaratively construct a build graph that will be executed by an external
+// runner.
+pub fn build(b: *std.Build) void {
+ // Standard target options allows the person running `zig build` to choose
+ // what target to build for. Here we do not override the defaults, which
+ // means any target is allowed, and the default is native. Other options
+ // for restricting supported target set are available.
+ const target = b.standardTargetOptions(.{});
+
+ // Standard optimization options allow the person running `zig build` to select
+ // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
+ // set a preferred release mode, allowing the user to decide how to optimize.
+ const optimize = b.standardOptimizeOption(.{});
+
+ const lib = b.addStaticLibrary(.{
+ .name = "stitch",
+ // In this case the main source file is merely a path, however, in more
+ // complicated build scripts, this could be a generated file.
+ .root_source_file = b.path("src/root.zig"),
+ .target = target,
+ .optimize = optimize,
+ });
+
+ // This declares intent for the library to be installed into the standard
+ // location when the user invokes the "install" step (the default step when
+ // running `zig build`).
+ b.installArtifact(lib);
+
+ const exe = b.addExecutable(.{
+ .name = "stitch",
+ .root_source_file = b.path("src/main.zig"),
+ .target = target,
+ .optimize = optimize,
+ });
+
+ // This declares intent for the executable to be installed into the
+ // standard location when the user invokes the "install" step (the default
+ // step when running `zig build`).
+ b.installArtifact(exe);
+
+ // This *creates* a Run step in the build graph, to be executed when another
+ // step is evaluated that depends on it. The next line below will establish
+ // such a dependency.
+ const run_cmd = b.addRunArtifact(exe);
+
+ // By making the run step depend on the install step, it will be run from the
+ // installation directory rather than directly from within the cache directory.
+ // This is not necessary, however, if the application depends on other installed
+ // files, this ensures they will be present and in the expected location.
+ run_cmd.step.dependOn(b.getInstallStep());
+
+ // This allows the user to pass arguments to the application in the build
+ // command itself, like this: `zig build run -- arg1 arg2 etc`
+ if (b.args) |args| {
+ run_cmd.addArgs(args);
+ }
+
+ // This creates a build step. It will be visible in the `zig build --help` menu,
+ // and can be selected like this: `zig build run`
+ // This will evaluate the `run` step rather than the default, which is "install".
+ const run_step = b.step("run", "Run the app");
+ run_step.dependOn(&run_cmd.step);
+
+ // Creates a step for unit testing. This only builds the test executable
+ // but does not run it.
+ const lib_unit_tests = b.addTest(.{
+ .root_source_file = b.path("src/root.zig"),
+ .target = target,
+ .optimize = optimize,
+ });
+
+ const run_lib_unit_tests = b.addRunArtifact(lib_unit_tests);
+
+ const exe_unit_tests = b.addTest(.{
+ .root_source_file = b.path("src/main.zig"),
+ .target = target,
+ .optimize = optimize,
+ });
+
+ const run_exe_unit_tests = b.addRunArtifact(exe_unit_tests);
+
+ // Similar to creating the run step earlier, this exposes a `test` step to
+ // the `zig build --help` menu, providing a way for the user to request
+ // running the unit tests.
+ const test_step = b.step("test", "Run unit tests");
+ test_step.dependOn(&run_lib_unit_tests.step);
+ test_step.dependOn(&run_exe_unit_tests.step);
+}
zig/build.zig.zon
@@ -0,0 +1,72 @@
+.{
+ // This is the default name used by packages depending on this one. For
+ // example, when a user runs `zig fetch --save <url>`, this field is used
+ // as the key in the `dependencies` table. Although the user can choose a
+ // different name, most users will stick with this provided value.
+ //
+ // It is redundant to include "zig" in this name because it is already
+ // within the Zig package namespace.
+ .name = "stitch",
+
+ // This is a [Semantic Version](https://semver.org/).
+ // In a future version of Zig it will be used for package deduplication.
+ .version = "0.0.0",
+
+ // This field is optional.
+ // This is currently advisory only; Zig does not yet do anything
+ // with this value.
+ //.minimum_zig_version = "0.11.0",
+
+ // This field is optional.
+ // Each dependency must either provide a `url` and `hash`, or a `path`.
+ // `zig build --fetch` can be used to fetch all dependencies of a package, recursively.
+ // Once all dependencies are fetched, `zig build` no longer requires
+ // internet connectivity.
+ .dependencies = .{
+ // See `zig fetch --save <url>` for a command-line interface for adding dependencies.
+ //.example = .{
+ // // When updating this field to a new URL, be sure to delete the corresponding
+ // // `hash`, otherwise you are communicating that you expect to find the old hash at
+ // // the new URL.
+ // .url = "https://example.com/foo.tar.gz",
+ //
+ // // This is computed from the file contents of the directory of files that is
+ // // obtained after fetching `url` and applying the inclusion rules given by
+ // // `paths`.
+ // //
+ // // This field is the source of truth; packages do not come from a `url`; they
+ // // come from a `hash`. `url` is just one of many possible mirrors for how to
+ // // obtain a package matching this `hash`.
+ // //
+ // // Uses the [multihash](https://multiformats.io/multihash/) format.
+ // .hash = "...",
+ //
+ // // When this is provided, the package is found in a directory relative to the
+ // // build root. In this case the package's hash is irrelevant and therefore not
+ // // computed. This field and `url` are mutually exclusive.
+ // .path = "foo",
+
+ // // When this is set to `true`, a package is declared to be lazily
+ // // fetched. This makes the dependency only get fetched if it is
+ // // actually used.
+ // .lazy = false,
+ //},
+ },
+
+ // Specifies the set of files and directories that are included in this package.
+ // Only files and directories listed here are included in the `hash` that
+ // is computed for this package. Only files listed here will remain on disk
+ // when using the zig package manager. As a rule of thumb, one should list
+ // files required for compilation plus any license(s).
+ // Paths are relative to the build root. Use the empty string (`""`) to refer to
+ // the build root itself.
+ // A directory listed here means that all files within, recursively, are included.
+ .paths = .{
+ "build.zig",
+ "build.zig.zon",
+ "src",
+ // For example...
+ //"LICENSE",
+ //"README.md",
+ },
+}