// ptun.go
// Author: Eric Bower
// Date: 2025-04-23

1package tunkit
2
3import (
4 "context"
5 "errors"
6 "io"
7 "log/slog"
8 "net"
9 "sync"
10
11 "github.com/picosh/pico/pkg/pssh"
12 "golang.org/x/crypto/ssh"
13)
14
// forwardedTCPPayload mirrors the extra data carried by an SSH port-forward
// channel-open message: the host and port to connect to, followed by the
// originator's address and port (the layout RFC 4254 §7.2 defines for
// "direct-tcpip"). ssh.Unmarshal decodes struct fields strictly in
// declaration order, so the order below is part of the wire format and must
// not change.
type forwardedTCPPayload struct {
	Addr       string // host to connect to
	Port       uint32 // port to connect to
	OriginAddr string // originator address
	OriginPort uint32 // originator port
}
21
22type Tunnel interface {
23 CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
24 GetLogger() *slog.Logger
25 Close(ctx *pssh.SSHServerConnSession) error
26}
27
28func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
29 return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
30 check := &forwardedTCPPayload{}
31 err := ssh.Unmarshal(newChan.ExtraData(), check)
32 logger := handler.GetLogger()
33 if err != nil {
34 logger.Error(
35 "error unmarshaling information",
36 "err", err,
37 )
38 return err
39 }
40
41 log := logger.With(
42 "addr", check.Addr,
43 "port", check.Port,
44 "origAddr", check.OriginAddr,
45 "origPort", check.OriginPort,
46 )
47 log.Info("local forward request")
48
49 ch, reqs, err := newChan.Accept()
50 if err != nil {
51 log.Error("cannot accept new channel", "err", err)
52 return err
53 }
54
55 origCtx, cancel := context.WithCancel(context.Background())
56 ctx := &pssh.SSHServerConnSession{
57 Channel: ch,
58 SSHServerConn: sc,
59 Ctx: origCtx,
60 CancelFunc: cancel,
61 }
62
63 go ssh.DiscardRequests(reqs)
64
65 go func() {
66 downConn, err := handler.CreateConn(ctx)
67 if err != nil {
68 log.Error("unable to connect to conn", "err", err)
69 ch.Close()
70 return
71 }
72 defer downConn.Close()
73
74 var wg sync.WaitGroup
75 wg.Add(2)
76
77 go func() {
78 defer wg.Done()
79 defer func() {
80 _ = ch.CloseWrite()
81 }()
82 defer downConn.Close()
83 _, err := io.Copy(ch, downConn)
84 if err != nil {
85 if !errors.Is(err, net.ErrClosed) {
86 log.Error("io copy", "err", err)
87 }
88 }
89 }()
90 go func() {
91 defer wg.Done()
92 defer ch.Close()
93 defer downConn.Close()
94 _, err := io.Copy(downConn, ch)
95 if err != nil {
96 if !errors.Is(err, net.ErrClosed) {
97 log.Error("io copy", "err", err)
98 }
99 }
100 }()
101
102 wg.Wait()
103 }()
104
105 <-ctx.Done()
106 err = handler.Close(ctx)
107 if err != nil {
108 log.Error("tunnel handler error", "err", err)
109 }
110 return err
111 }
112}