Antonio Mika
·
2025-03-12
scp.go
1package scp
2
3import (
4 "fmt"
5
6 "github.com/picosh/pico/pkg/pssh"
7 "github.com/picosh/pico/pkg/send/utils"
8)
9
10func Middleware(writeHandler utils.CopyFromClientHandler) pssh.SSHServerMiddleware {
11 return func(sshHandler pssh.SSHServerHandler) pssh.SSHServerHandler {
12 return func(session *pssh.SSHServerConnSession) error {
13 cmd := session.Command()
14 if len(cmd) == 0 || cmd[0] != "scp" {
15 return sshHandler(session)
16 }
17
18 logger := writeHandler.GetLogger(session).With(
19 "scp", true,
20 "cmd", cmd,
21 )
22
23 defer func() {
24 if r := recover(); r != nil {
25 logger.Error("error running scp middleware", "err", r)
26 _, _ = session.Stderr().Write([]byte("error running scp middleware, check the flags you are using\r\n"))
27 }
28 }()
29
30 info := GetInfo(cmd)
31 if !info.Ok {
32 return sshHandler(session)
33 }
34
35 var err error
36
37 switch info.Op {
38 case OpCopyToClient:
39 if writeHandler == nil {
40 err = fmt.Errorf("no handler provided for scp -t")
41 break
42 }
43 err = copyToClient(session, info, writeHandler)
44 case OpCopyFromClient:
45 if writeHandler == nil {
46 err = fmt.Errorf("no handler provided for scp -t")
47 break
48 }
49 err = copyFromClient(session, info, writeHandler)
50 }
51 if err != nil {
52 utils.ErrorHandler(session, err)
53 }
54
55 return err
56 }
57 }
58}
59
// Op defines which kind of SCP operation is going on.
type Op byte

const (
	// OpCopyToClient is when a file is being copied from the server to the
	// client (the server was invoked as "scp -f").
	OpCopyToClient Op = 'f'

	// OpCopyFromClient is when a file is being copied from the client into
	// the server (the server was invoked as "scp -t").
	OpCopyFromClient Op = 't'
)

// Info provides some information about the current SCP operation.
type Info struct {
	// Ok is true if the current session is an SCP session.
	Ok bool

	// Recursive is true if it's a recursive SCP.
	Recursive bool

	// Path is the server path of the SCP operation.
	Path string

	// Op is the SCP operation kind. It is zero if neither -f nor -t was given.
	Op Op
}

// GetInfo parses a server-side scp command line such as
// ["scp", "-r", "-t", "--", "/dest"] and reports what was requested.
//
// Ok is true whenever cmd[0] is "scp". The operand following -f or -t
// becomes Path. A "--" end-of-options marker directly after the flag is
// skipped; OpenSSH's scp client sends one before paths that begin with '-'.
// The consumed operand is not re-interpreted as a flag. If no operand follows
// -f or -t, Path is not set from it (previously this indexed past the end of
// cmd and panicked).
func GetInfo(cmd []string) Info {
	info := Info{}
	if len(cmd) == 0 || cmd[0] != "scp" {
		return info
	}

	for i := 1; i < len(cmd); i++ {
		arg := cmd[i]
		switch arg {
		case "-r":
			info.Recursive = true
		case "-f", "-t":
			info.Op = OpCopyToClient
			if arg == "-t" {
				info.Op = OpCopyFromClient
			}

			next := i + 1
			if next < len(cmd) && cmd[next] == "--" {
				next++
			}
			if next < len(cmd) {
				info.Path = cmd[next]
				// Skip the operand so a path like "-r" isn't parsed as a flag.
				i = next
			}
		}
	}

	info.Ok = true
	return info
}