repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / tunkit
Antonio Mika  ·  2025-03-12

ptun.go

  1package tunkit
  2
  3import (
  4	"errors"
  5	"io"
  6	"log/slog"
  7	"net"
  8	"sync"
  9
 10	"github.com/picosh/pico/pkg/pssh"
 11	"golang.org/x/crypto/ssh"
 12)
 13
// forwardedTCPPayload is the structure LocalForwardHandler decodes a new
// channel's extra data into via ssh.Unmarshal. The field order and types
// define the decoded layout, so they must not be reordered.
//
// NOTE(review): the layout appears to match the "direct-tcpip" channel-open
// data from RFC 4254 section 7.2 (host to connect, port to connect,
// originator address, originator port) — confirm against the spec.
type forwardedTCPPayload struct {
	Addr       string // logged as "addr"; presumably the requested destination host
	Port       uint32 // logged as "port"; presumably the requested destination port
	OriginAddr string // logged as "origAddr"; presumably the originator's address
	OriginPort uint32 // logged as "origPort"; presumably the originator's port
}
 20
// Tunnel is the backend that LocalForwardHandler proxies an accepted SSH
// channel to.
type Tunnel interface {
	// CreateConn returns the connection that the accepted channel is piped
	// to. LocalForwardHandler calls it once per accepted channel, from a
	// separate goroutine, and closes the returned connection when proxying
	// ends.
	CreateConn(ctx *pssh.SSHServerConnSession) (net.Conn, error)
	// GetLogger returns the logger LocalForwardHandler uses for its
	// request-scoped log output.
	GetLogger() *slog.Logger
	// Close is called by LocalForwardHandler after ctx.Done() is signalled;
	// its error is logged and returned from the middleware.
	Close(ctx *pssh.SSHServerConnSession) error
}
 26
 27func LocalForwardHandler(handler Tunnel) pssh.SSHServerChannelMiddleware {
 28	return func(newChan ssh.NewChannel, sc *pssh.SSHServerConn) error {
 29		check := &forwardedTCPPayload{}
 30		err := ssh.Unmarshal(newChan.ExtraData(), check)
 31		logger := handler.GetLogger()
 32		if err != nil {
 33			logger.Error(
 34				"error unmarshaling information",
 35				"err", err,
 36			)
 37			return err
 38		}
 39
 40		log := logger.With(
 41			"addr", check.Addr,
 42			"port", check.Port,
 43			"origAddr", check.OriginAddr,
 44			"origPort", check.OriginPort,
 45		)
 46		log.Info("local forward request")
 47
 48		ch, reqs, err := newChan.Accept()
 49		if err != nil {
 50			log.Error("cannot accept new channel", "err", err)
 51			return err
 52		}
 53
 54		ctx := &pssh.SSHServerConnSession{
 55			Channel:       ch,
 56			SSHServerConn: sc,
 57		}
 58
 59		go ssh.DiscardRequests(reqs)
 60
 61		go func() {
 62			downConn, err := handler.CreateConn(ctx)
 63			if err != nil {
 64				log.Error("unable to connect to conn", "err", err)
 65				ch.Close()
 66				return
 67			}
 68			defer downConn.Close()
 69
 70			var wg sync.WaitGroup
 71			wg.Add(2)
 72
 73			go func() {
 74				defer wg.Done()
 75				defer func() {
 76					_ = ch.CloseWrite()
 77				}()
 78				defer downConn.Close()
 79				_, err := io.Copy(ch, downConn)
 80				if err != nil {
 81					if !errors.Is(err, net.ErrClosed) {
 82						log.Error("io copy", "err", err)
 83					}
 84				}
 85			}()
 86			go func() {
 87				defer wg.Done()
 88				defer ch.Close()
 89				defer downConn.Close()
 90				_, err := io.Copy(downConn, ch)
 91				if err != nil {
 92					if !errors.Is(err, net.ErrClosed) {
 93						log.Error("io copy", "err", err)
 94					}
 95				}
 96			}()
 97
 98			wg.Wait()
 99		}()
100
101		<-ctx.Done()
102		err = handler.Close(ctx)
103		if err != nil {
104			log.Error("tunnel handler error", "err", err)
105		}
106		return err
107	}
108}