Eric Bower
·
2025-05-25
ptun.go
1package tunkit
2
3import (
4 "context"
5 "errors"
6 "io"
7 "log/slog"
8 "net"
9 "sync"
10
11 "github.com/picosh/pico/pkg/pssh"
12 "golang.org/x/crypto/ssh"
13)
14
// forwardedTCPPayload is the request payload carried in the extra data of a
// port-forwarding channel-open message. ssh.Unmarshal fills the fields in
// declaration order, so the order below IS the wire format and must not be
// changed. The layout matches RFC 4254 §7.2 "direct-tcpip": host and port to
// connect to, followed by the originator's address and port.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
21
22type Tunnel interface {
23 CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
24 GetLogger() *slog.Logger
25 Close(ctx *pssh.SSHServerConnSession) error
26}
27
28func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
29 return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
30 check := &forwardedTCPPayload{}
31 err := ssh.Unmarshal(newChan.ExtraData(), check)
32 logger := handler.GetLogger()
33 if err != nil {
34 logger.Error(
35 "error unmarshaling information",
36 "err", err,
37 )
38 return err
39 }
40
41 log := logger.With(
42 "addr", check.Addr,
43 "port", check.Port,
44 "origAddr", check.OriginAddr,
45 "origPort", check.OriginPort,
46 )
47 log.Info("local forward request")
48
49 ch, reqs, err := newChan.Accept()
50 if err != nil {
51 log.Error("cannot accept new channel", "err", err)
52 return err
53 }
54
55 origCtx, cancel := context.WithCancel(context.Background())
56 ctx := &pssh.SSHServerConnSession{
57 Channel: ch,
58 SSHServerConn: sc,
59 Ctx: origCtx,
60 CancelFunc: cancel,
61 }
62
63 go ssh.DiscardRequests(reqs)
64
65 go func() {
66 downConn, err := handler.CreateConn(ctx)
67 if err != nil {
68 log.Error("unable to connect to conn", "err", err)
69 _ = ch.Close()
70 return
71 }
72 defer func() {
73 _ = downConn.Close()
74 }()
75
76 var wg sync.WaitGroup
77 wg.Add(2)
78
79 go func() {
80 defer wg.Done()
81 defer func() {
82 _ = ch.CloseWrite()
83 _ = downConn.Close()
84 }()
85 _, err := io.Copy(ch, downConn)
86 if err != nil {
87 if !errors.Is(err, net.ErrClosed) {
88 log.Error("io copy", "err", err)
89 }
90 }
91 }()
92 go func() {
93 defer wg.Done()
94 defer func() {
95 _ = ch.Close()
96 _ = downConn.Close()
97 }()
98 _, err := io.Copy(downConn, ch)
99 if err != nil {
100 if !errors.Is(err, net.ErrClosed) {
101 log.Error("io copy", "err", err)
102 }
103 }
104 }()
105
106 wg.Wait()
107 }()
108
109 <-ctx.Done()
110 err = handler.Close(ctx)
111 if err != nil {
112 log.Error("tunnel handler error", "err", err)
113 }
114 return err
115 }
116}