Antonio Mika
·
2025-03-12
web-handler.go
1package tunkit
2
3import (
4 "fmt"
5 "log/slog"
6 "net"
7 "os"
8
9 "github.com/picosh/pico/pkg/pssh"
10)
11
// ctxAddressKey is the key under which the Unix socket path of a session's
// web-tunnel listener is stored on a *pssh.SSHServerConnSession (see
// setAddressCtx / getAddressCtx). Because the type is unexported, code in
// other packages cannot construct this key, so it cannot collide with their
// session values.
type ctxAddressKey struct{}
13
14func getAddressCtx(ctx *pssh.SSHServerConnSession) (string, error) {
15 address, ok := ctx.Value(ctxAddressKey{}).(string)
16 if address == "" || !ok {
17 return address, fmt.Errorf("address not set on `*pssh.SSHServerConnSession()` for connection")
18 }
19 return address, nil
20}
21func setAddressCtx(ctx *pssh.SSHServerConnSession, address string) {
22 ctx.SetValue(ctxAddressKey{}, address)
23}
24
// WebTunnelHandler connects an SSH session to an HTTP handler over a
// per-session Unix domain socket. CreateListener opens the socket and records
// it on the session, CreateConn runs httpServe and then dials the recorded
// socket path, and Close closes the recorded listener.
type WebTunnelHandler struct {
	// HttpHandler is exposed via GetHttpHandler; presumably httpServe uses it
	// to build the handler served on the session's listener — TODO confirm.
	HttpHandler HttpHandlerFn
	// Logger is exposed via GetLogger and passed to httpServe by CreateConn.
	Logger *slog.Logger
}
29
30func NewWebTunnelHandler(handler HttpHandlerFn, logger *slog.Logger) *WebTunnelHandler {
31 return &WebTunnelHandler{
32 HttpHandler: handler,
33 Logger: logger,
34 }
35}
36
37func (wt *WebTunnelHandler) GetLogger() *slog.Logger {
38 return wt.Logger
39}
40
41func (wt *WebTunnelHandler) GetHttpHandler() HttpHandlerFn {
42 return wt.HttpHandler
43}
44
45func (wt *WebTunnelHandler) Close(ctx *pssh.SSHServerConnSession) error {
46 listener, err := getListenerCtx(ctx)
47 if err != nil {
48 return err
49 }
50
51 if listener != nil {
52 _ = listener.Close()
53 setListenerCtx(ctx, nil)
54 }
55
56 return nil
57}
58
59func (wt *WebTunnelHandler) CreateListener(ctx *pssh.SSHServerConnSession) (net.Listener, error) {
60 tempFile, err := os.CreateTemp("", "")
61 if err != nil {
62 return nil, err
63 }
64
65 tempFile.Close()
66 address := tempFile.Name()
67 os.Remove(address)
68
69 connListener, err := net.Listen("unix", address)
70 if err != nil {
71 return nil, err
72 }
73 setAddressCtx(ctx, address)
74 setListenerCtx(ctx, connListener)
75
76 return connListener, nil
77}
78
79func (wt *WebTunnelHandler) CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error) {
80 _, err := httpServe(wt, ctx, wt.GetLogger())
81 if err != nil {
82 wt.GetLogger().Info("unable to create listener", "err", err)
83 return nil, err
84 }
85
86 address, err := getAddressCtx(ctx)
87 if err != nil {
88 return nil, err
89 }
90
91 return net.Dial("unix", address)
92}