repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / tunkit
Eric Bower  ·  2025-04-23

ptun.go

  1package tunkit
  2
  3import (
  4	"context"
  5	"errors"
  6	"io"
  7	"log/slog"
  8	"net"
  9	"sync"
 10
 11	"github.com/picosh/pico/pkg/pssh"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type forwardedTCPPayload struct {
 16	Addr       string
 17	Port       uint32
 18	OriginAddr string
 19	OriginPort uint32
 20}
 21
 22type Tunnel interface {
 23	CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
 24	GetLogger() *slog.Logger
 25	Close(ctx *pssh.SSHServerConnSession) error
 26}
 27
 28func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
 29	return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
 30		check := &forwardedTCPPayload{}
 31		err := ssh.Unmarshal(newChan.ExtraData(), check)
 32		logger := handler.GetLogger()
 33		if err != nil {
 34			logger.Error(
 35				"error unmarshaling information",
 36				"err", err,
 37			)
 38			return err
 39		}
 40
 41		log := logger.With(
 42			"addr", check.Addr,
 43			"port", check.Port,
 44			"origAddr", check.OriginAddr,
 45			"origPort", check.OriginPort,
 46		)
 47		log.Info("local forward request")
 48
 49		ch, reqs, err := newChan.Accept()
 50		if err != nil {
 51			log.Error("cannot accept new channel", "err", err)
 52			return err
 53		}
 54
 55		origCtx, cancel := context.WithCancel(context.Background())
 56		ctx := &pssh.SSHServerConnSession{
 57			Channel:       ch,
 58			SSHServerConn: sc,
 59			Ctx:           origCtx,
 60			CancelFunc:    cancel,
 61		}
 62
 63		go ssh.DiscardRequests(reqs)
 64
 65		go func() {
 66			downConn, err := handler.CreateConn(ctx)
 67			if err != nil {
 68				log.Error("unable to connect to conn", "err", err)
 69				ch.Close()
 70				return
 71			}
 72			defer downConn.Close()
 73
 74			var wg sync.WaitGroup
 75			wg.Add(2)
 76
 77			go func() {
 78				defer wg.Done()
 79				defer func() {
 80					_ = ch.CloseWrite()
 81				}()
 82				defer downConn.Close()
 83				_, err := io.Copy(ch, downConn)
 84				if err != nil {
 85					if !errors.Is(err, net.ErrClosed) {
 86						log.Error("io copy", "err", err)
 87					}
 88				}
 89			}()
 90			go func() {
 91				defer wg.Done()
 92				defer ch.Close()
 93				defer downConn.Close()
 94				_, err := io.Copy(downConn, ch)
 95				if err != nil {
 96					if !errors.Is(err, net.ErrClosed) {
 97						log.Error("io copy", "err", err)
 98					}
 99				}
100			}()
101
102			wg.Wait()
103		}()
104
105		<-ctx.Done()
106		err = handler.Close(ctx)
107		if err != nil {
108			log.Error("tunnel handler error", "err", err)
109		}
110		return err
111	}
112}