// ptun.go — Antonio Mika, 2025-03-12
1package tunkit
2
3import (
4 "errors"
5 "io"
6 "log/slog"
7 "net"
8 "sync"
9
10 "github.com/picosh/pico/pkg/pssh"
11 "golang.org/x/crypto/ssh"
12)
13
// forwardedTCPPayload is the extra data attached to a port-forwarding channel
// open request; LocalForwardHandler decodes it with ssh.Unmarshal.
//
// ssh.Unmarshal fills struct fields positionally from the SSH wire encoding,
// so the declaration order below is part of the wire format — do not reorder.
// The layout appears to follow RFC 4254's direct-tcpip request (assumed: the
// channel type itself is not checked in this file).
type forwardedTCPPayload struct {
	// Addr/Port: the endpoint the client asked to reach (logged as "addr"/"port").
	Addr string
	Port uint32

	// OriginAddr/OriginPort: where the forwarded connection originated
	// (logged as "origAddr"/"origPort").
	OriginAddr string
	OriginPort uint32
}
20
21type Tunnel interface {
22 CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
23 GetLogger() *slog.Logger
24 Close(ctx *pssh.SSHServerConnSession) error
25}
26
27func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
28 return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
29 check := &forwardedTCPPayload{}
30 err := ssh.Unmarshal(newChan.ExtraData(), check)
31 logger := handler.GetLogger()
32 if err != nil {
33 logger.Error(
34 "error unmarshaling information",
35 "err", err,
36 )
37 return err
38 }
39
40 log := logger.With(
41 "addr", check.Addr,
42 "port", check.Port,
43 "origAddr", check.OriginAddr,
44 "origPort", check.OriginPort,
45 )
46 log.Info("local forward request")
47
48 ch, reqs, err := newChan.Accept()
49 if err != nil {
50 log.Error("cannot accept new channel", "err", err)
51 return err
52 }
53
54 ctx := &pssh.SSHServerConnSession{
55 Channel: ch,
56 SSHServerConn: sc,
57 }
58
59 go ssh.DiscardRequests(reqs)
60
61 go func() {
62 downConn, err := handler.CreateConn(ctx)
63 if err != nil {
64 log.Error("unable to connect to conn", "err", err)
65 ch.Close()
66 return
67 }
68 defer downConn.Close()
69
70 var wg sync.WaitGroup
71 wg.Add(2)
72
73 go func() {
74 defer wg.Done()
75 defer func() {
76 _ = ch.CloseWrite()
77 }()
78 defer downConn.Close()
79 _, err := io.Copy(ch, downConn)
80 if err != nil {
81 if !errors.Is(err, net.ErrClosed) {
82 log.Error("io copy", "err", err)
83 }
84 }
85 }()
86 go func() {
87 defer wg.Done()
88 defer ch.Close()
89 defer downConn.Close()
90 _, err := io.Copy(downConn, ch)
91 if err != nil {
92 if !errors.Is(err, net.ErrClosed) {
93 log.Error("io copy", "err", err)
94 }
95 }
96 }()
97
98 wg.Wait()
99 }()
100
101 <-ctx.Done()
102 err = handler.Close(ctx)
103 if err != nil {
104 log.Error("tunnel handler error", "err", err)
105 }
106 return err
107 }
108}