sftp.go (8321B) - raw
1 // SFTP server for users with Flounder accounts 2 // A lot of this is copied from SFTPGo, but simplified for our use case. 3 package main 4 5 import ( 6 "crypto/rand" 7 "crypto/rsa" 8 "crypto/x509" 9 "encoding/pem" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "log" 14 "net" 15 "os" 16 "path" 17 "runtime/debug" 18 "time" 19 20 "github.com/pkg/sftp" 21 "golang.org/x/crypto/ssh" 22 ) 23 24 type Connection struct { 25 User string 26 } 27 28 func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { 29 // check user perms -- cant read others hidden files 30 userDir := getUserDirectory(con.User) // NOTE -- not cross platform 31 fullpath := path.Join(userDir, cleanPath(request.Filepath)) 32 f, err := os.OpenFile(fullpath, os.O_RDONLY, 0) 33 if err != nil { 34 return nil, err 35 } 36 return f, nil 37 } 38 39 func (conn *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { 40 // check user perms -- cant write others files 41 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform 42 fullpath := path.Join(userDir, cleanPath(request.Filepath)) 43 err := checkIfValidFile(conn.User, fullpath, []byte{}) 44 if err != nil { 45 return nil, err 46 } 47 f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) 48 if err != nil { 49 return nil, err 50 } 51 return f, nil 52 } 53 54 func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { 55 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform 56 fullpath := path.Join(userDir, cleanPath(request.Filepath)) 57 switch request.Method { 58 case "List": 59 f, err := os.Open(fullpath) 60 if err != nil { 61 return nil, err 62 } 63 fileInfo, err := f.Readdir(-1) 64 if err != nil { 65 return nil, err 66 } 67 return listerat(fileInfo), nil 68 case "Stat": 69 stat, err := os.Stat(fullpath) 70 if err != nil { 71 return nil, err 72 } 73 return listerat([]os.FileInfo{stat}), nil 74 } 75 return nil, fmt.Errorf("Invalid command") 76 } 77 78 func (conn *Connection) Filecmd(request *sftp.Request) error { 79 // remove, rename, setstat? find out 80 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform 81 fullpath := path.Join(userDir, cleanPath(request.Filepath)) 82 targetPath := path.Join(userDir, cleanPath(request.Target)) 83 var err error 84 switch request.Method { 85 case "Remove": 86 err = os.Remove(fullpath) 87 case "Mkdir": 88 err = os.Mkdir(fullpath, 0755) 89 case "Rename": 90 err := checkIfValidFile(conn.User, targetPath, []byte{}) 91 if err != nil { 92 return err 93 } 94 err = os.Rename(fullpath, targetPath) 95 } 96 if err != nil { 97 return err 98 } 99 // Rename, Mkdir 100 return nil 101 } 102 103 // TODO hide hidden folders 104 // Users have write persm on their files, read perms on all 105 106 func buildHandlers(connection *Connection) sftp.Handlers { 107 return sftp.Handlers{ 108 connection, 109 connection, 110 connection, 111 connection, 112 } 113 } 114 115 // Based on example server code from golang.org/x/crypto/ssh and server_standalone 116 func runSFTPServer() { 117 if !c.EnableSFTP { 118 return 119 } 120 // An SSH server is represented by a ServerConfig, which holds 121 // certificate details and handles authentication of ServerConns. 122 config := &ssh.ServerConfig{ 123 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { 124 // Should use constant-time compare (or better, salt+hash) in 125 // a production setting. 126 if isOkUsername(c.User()) != nil { // extra check, probably unnecessary 127 return nil, fmt.Errorf("Invalid username") 128 } 129 _, _, err := checkLogin(c.User(), string(pass)) 130 // TODO maybe give admin extra permissions? 131 if err != nil { 132 return nil, fmt.Errorf("password rejected for %q", c.User()) 133 } else { 134 log.Printf("Login: %s\n", c.User()) 135 return nil, nil 136 } 137 }, 138 } 139 140 // TODO generate key automatically 141 if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) { 142 // path/to/whatever does not exist 143 log.Println("Host key not found, generating host key") 144 err := GenerateRSAKeys() 145 if err != nil { 146 log.Fatal(err) 147 } 148 } 149 150 privateBytes, err := ioutil.ReadFile(c.HostKeyPath) 151 if err != nil { 152 log.Fatal("Failed to load private key", err) 153 } 154 155 private, err := ssh.ParsePrivateKey(privateBytes) 156 if err != nil { 157 log.Fatal("Failed to parse private key", err) 158 } 159 160 config.AddHostKey(private) 161 162 listener, err := net.Listen("tcp", "0.0.0.0:2024") 163 if err != nil { 164 log.Fatal("failed to listen for connection", err) 165 } 166 167 log.Printf("SFTP server listening on %v\n", listener.Addr()) 168 169 for { 170 conn, err := listener.Accept() 171 if err != nil { 172 log.Fatal(err) 173 } 174 go acceptInboundConnection(conn, config) 175 } 176 } 177 178 func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { 179 defer func() { 180 if r := recover(); r != nil { 181 log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) 182 } 183 }() 184 ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String()) 185 log.Println("Request from IP " + ipAddr) 186 limiter := getVisitor(ipAddr) 187 if limiter.Allow() == false { 188 conn.Close() 189 return 190 } 191 // Before beginning a handshake must be performed on the incoming net.Conn 192 // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH 193 conn.SetDeadline(time.Now().Add(2 * time.Minute)) 194 195 // Before use, a handshake must be performed on the incoming net.Conn. 196 sconn, chans, reqs, err := ssh.NewServerConn(conn, config) 197 if err != nil { 198 log.Printf("failed to accept an incoming connection: %v", err) 199 return 200 } 201 log.Println("login detected:", sconn.User()) 202 fmt.Fprintf(os.Stderr, "SSH server established\n") 203 // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on 204 conn.SetDeadline(time.Time{}) 205 206 defer conn.Close() 207 208 // The incoming Request channel must be serviced. 209 go ssh.DiscardRequests(reqs) 210 211 // Service the incoming Channel channel. 212 channelCounter := int64(0) 213 for newChannel := range chans { 214 // Channels have a type, depending on the application level 215 // protocol intended. In the case of an SFTP session, this is "subsystem" 216 // with a payload string of "<length=4>sftp" 217 fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType()) 218 if newChannel.ChannelType() != "session" { 219 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") 220 fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType()) 221 continue 222 } 223 channel, requests, err := newChannel.Accept() 224 if err != nil { 225 log.Println("could not accept channel.", err) 226 continue 227 } 228 229 channelCounter++ 230 fmt.Fprintf(os.Stderr, "Channel accepted\n") 231 232 // Sessions have out-of-band requests such as "shell", 233 // "pty-req" and "env". Here we handle only the 234 // "subsystem" request. 235 go func(in <-chan *ssh.Request) { 236 for req := range in { 237 fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type) 238 ok := false 239 switch req.Type { 240 case "subsystem": 241 fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:]) 242 if string(req.Payload[4:]) == "sftp" { 243 ok = true 244 } 245 } 246 fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok) 247 req.Reply(ok, nil) 248 } 249 }(requests) 250 connection := Connection{sconn.User()} 251 root := buildHandlers(&connection) 252 server := sftp.NewRequestServer(channel, root) 253 if err := server.Serve(); err == io.EOF { 254 server.Close() 255 log.Println("sftp client exited session.") 256 } else if err != nil { 257 log.Println("sftp server completed with error:", err) 258 return 259 } 260 } 261 } 262 263 // GenerateRSAKeys generate rsa private and public keys and write the 264 // private key to specified file and the public key to the specified 265 // file adding the .pub suffix 266 func GenerateRSAKeys() error { 267 key, err := rsa.GenerateKey(rand.Reader, 4096) 268 if err != nil { 269 return err 270 } 271 272 o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 273 if err != nil { 274 return err 275 } 276 defer o.Close() 277 278 priv := &pem.Block{ 279 Type: "RSA PRIVATE KEY", 280 Bytes: x509.MarshalPKCS1PrivateKey(key), 281 } 282 283 if err := pem.Encode(o, priv); err != nil { 284 return err 285 } 286 287 pub, err := ssh.NewPublicKey(&key.PublicKey) 288 if err != nil { 289 return err 290 } 291 return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600) 292 } 293 294 type listerat []os.FileInfo 295 296 // Modeled after strings.Reader's ReadAt() implementation 297 func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) { 298 var n int 299 if offset >= int64(len(f)) { 300 return 0, io.EOF 301 } 302 n = copy(ls, f[offset:]) 303 if n < len(ls) { 304 return n, io.EOF 305 } 306 return n, nil 307 }