repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / tunkit
Antonio Mika  ·  2025-03-12

web-handler.go

 1package tunkit
 2
 3import (
 4	"fmt"
 5	"log/slog"
 6	"net"
 7	"os"
 8
 9	"github.com/picosh/pico/pkg/pssh"
10)
11
// ctxAddressKey is the key type under which the listener address is stored
// on a session (see getAddressCtx and setAddressCtx). Using a private empty
// struct type keeps the key from colliding with keys set by other code.
type ctxAddressKey struct{}
13
14func getAddressCtx(ctx *pssh.SSHServerConnSession) (string, error) {
15	address, ok := ctx.Value(ctxAddressKey{}).(string)
16	if address == "" || !ok {
17		return address, fmt.Errorf("address not set on `*pssh.SSHServerConnSession()` for connection")
18	}
19	return address, nil
20}
21func setAddressCtx(ctx *pssh.SSHServerConnSession, address string) {
22	ctx.SetValue(ctxAddressKey{}, address)
23}
24
25type WebTunnelHandler struct {
26	HttpHandler HttpHandlerFn
27	Logger      *slog.Logger
28}
29
30func NewWebTunnelHandler(handler HttpHandlerFn, logger *slog.Logger) *WebTunnelHandler {
31	return &WebTunnelHandler{
32		HttpHandler: handler,
33		Logger:      logger,
34	}
35}
36
37func (wt *WebTunnelHandler) GetLogger() *slog.Logger {
38	return wt.Logger
39}
40
41func (wt *WebTunnelHandler) GetHttpHandler() HttpHandlerFn {
42	return wt.HttpHandler
43}
44
45func (wt *WebTunnelHandler) Close(ctx *pssh.SSHServerConnSession) error {
46	listener, err := getListenerCtx(ctx)
47	if err != nil {
48		return err
49	}
50
51	if listener != nil {
52		_ = listener.Close()
53		setListenerCtx(ctx, nil)
54	}
55
56	return nil
57}
58
59func (wt *WebTunnelHandler) CreateListener(ctx *pssh.SSHServerConnSession) (net.Listener, error) {
60	tempFile, err := os.CreateTemp("", "")
61	if err != nil {
62		return nil, err
63	}
64
65	tempFile.Close()
66	address := tempFile.Name()
67	os.Remove(address)
68
69	connListener, err := net.Listen("unix", address)
70	if err != nil {
71		return nil, err
72	}
73	setAddressCtx(ctx, address)
74	setListenerCtx(ctx, connListener)
75
76	return connListener, nil
77}
78
79func (wt *WebTunnelHandler) CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error) {
80	_, err := httpServe(wt, ctx, wt.GetLogger())
81	if err != nil {
82		wt.GetLogger().Info("unable to create listener", "err", err)
83		return nil, err
84	}
85
86	address, err := getAddressCtx(ctx)
87	if err != nil {
88		return nil, err
89	}
90
91	return net.Dial("unix", address)
92}