repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / tunkit
Eric Bower  ·  2025-05-25

ptun.go

  1package tunkit
  2
  3import (
  4	"context"
  5	"errors"
  6	"io"
  7	"log/slog"
  8	"net"
  9	"sync"
 10
 11	"github.com/picosh/pico/pkg/pssh"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type forwardedTCPPayload struct {
 16	Addr       string
 17	Port       uint32
 18	OriginAddr string
 19	OriginPort uint32
 20}
 21
 22type Tunnel interface {
 23	CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
 24	GetLogger() *slog.Logger
 25	Close(ctx *pssh.SSHServerConnSession) error
 26}
 27
 28func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
 29	return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
 30		check := &forwardedTCPPayload{}
 31		err := ssh.Unmarshal(newChan.ExtraData(), check)
 32		logger := handler.GetLogger()
 33		if err != nil {
 34			logger.Error(
 35				"error unmarshaling information",
 36				"err", err,
 37			)
 38			return err
 39		}
 40
 41		log := logger.With(
 42			"addr", check.Addr,
 43			"port", check.Port,
 44			"origAddr", check.OriginAddr,
 45			"origPort", check.OriginPort,
 46		)
 47		log.Info("local forward request")
 48
 49		ch, reqs, err := newChan.Accept()
 50		if err != nil {
 51			log.Error("cannot accept new channel", "err", err)
 52			return err
 53		}
 54
 55		origCtx, cancel := context.WithCancel(context.Background())
 56		ctx := &pssh.SSHServerConnSession{
 57			Channel:       ch,
 58			SSHServerConn: sc,
 59			Ctx:           origCtx,
 60			CancelFunc:    cancel,
 61		}
 62
 63		go ssh.DiscardRequests(reqs)
 64
 65		go func() {
 66			downConn, err := handler.CreateConn(ctx)
 67			if err != nil {
 68				log.Error("unable to connect to conn", "err", err)
 69				_ = ch.Close()
 70				return
 71			}
 72			defer func() {
 73				_ = downConn.Close()
 74			}()
 75
 76			var wg sync.WaitGroup
 77			wg.Add(2)
 78
 79			go func() {
 80				defer wg.Done()
 81				defer func() {
 82					_ = ch.CloseWrite()
 83					_ = downConn.Close()
 84				}()
 85				_, err := io.Copy(ch, downConn)
 86				if err != nil {
 87					if !errors.Is(err, net.ErrClosed) {
 88						log.Error("io copy", "err", err)
 89					}
 90				}
 91			}()
 92			go func() {
 93				defer wg.Done()
 94				defer func() {
 95					_ = ch.Close()
 96					_ = downConn.Close()
 97				}()
 98				_, err := io.Copy(downConn, ch)
 99				if err != nil {
100					if !errors.Is(err, net.ErrClosed) {
101						log.Error("io copy", "err", err)
102					}
103				}
104			}()
105
106			wg.Wait()
107		}()
108
109		<-ctx.Done()
110		err = handler.Close(ctx)
111		if err != nil {
112			log.Error("tunnel handler error", "err", err)
113		}
114		return err
115	}
116}