repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / tunkit
Eric Bower  ·  2025-05-25

web-handler.go

 1package tunkit
 2
 3import (
 4	"fmt"
 5	"log/slog"
 6	"net"
 7	"os"
 8
 9	"github.com/picosh/pico/pkg/pssh"
10)
11
// ctxAddressKey is the key type used by getAddressCtx and setAddressCtx to
// store and look up the address string on a *pssh.SSHServerConnSession.
// Being a distinct unexported type, it cannot collide with keys declared
// by other packages.
type ctxAddressKey struct{}
13
14func getAddressCtx(ctx *pssh.SSHServerConnSession) (string, error) {
15	address, ok := ctx.Value(ctxAddressKey{}).(string)
16	if address == "" || !ok {
17		return address, fmt.Errorf("address not set on `*pssh.SSHServerConnSession()` for connection")
18	}
19	return address, nil
20}
21func setAddressCtx(ctx *pssh.SSHServerConnSession, address string) {
22	ctx.SetValue(ctxAddressKey{}, address)
23}
24
25type WebTunnelHandler struct {
26	HttpHandler HttpHandlerFn
27	Logger      *slog.Logger
28}
29
30func NewWebTunnelHandler(handler HttpHandlerFn, logger *slog.Logger) *WebTunnelHandler {
31	return &WebTunnelHandler{
32		HttpHandler: handler,
33		Logger:      logger,
34	}
35}
36
37func (wt *WebTunnelHandler) GetLogger() *slog.Logger {
38	return wt.Logger
39}
40
41func (wt *WebTunnelHandler) GetHttpHandler() HttpHandlerFn {
42	return wt.HttpHandler
43}
44
45func (wt *WebTunnelHandler) Close(ctx *pssh.SSHServerConnSession) error {
46	listener, err := getListenerCtx(ctx)
47	if err != nil {
48		return err
49	}
50
51	if listener != nil {
52		_ = listener.Close()
53		setListenerCtx(ctx, nil)
54	}
55
56	return nil
57}
58
59func (wt *WebTunnelHandler) CreateListener(ctx *pssh.SSHServerConnSession) (net.Listener, error) {
60	tempFile, err := os.CreateTemp("", "")
61	if err != nil {
62		return nil, err
63	}
64
65	err = tempFile.Close()
66	if err != nil {
67		return nil, err
68	}
69	address := tempFile.Name()
70	err = os.Remove(address)
71	if err != nil {
72		return nil, err
73	}
74
75	connListener, err := net.Listen("unix", address)
76	if err != nil {
77		return nil, err
78	}
79	setAddressCtx(ctx, address)
80	setListenerCtx(ctx, connListener)
81
82	return connListener, nil
83}
84
85func (wt *WebTunnelHandler) CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error) {
86	_, err := httpServe(wt, ctx, wt.GetLogger())
87	if err != nil {
88		wt.GetLogger().Info("unable to create listener", "err", err)
89		return nil, err
90	}
91
92	address, err := getAddressCtx(ctx)
93	if err != nil {
94		return nil, err
95	}
96
97	return net.Dial("unix", address)
98}