flounder

A simple Gemini site builder
git clone git://git.alexwennerberg.com/flounder
Log | Files | Refs | README | LICENSE

sftp.go (8321B) - raw


// SFTP server for users with Flounder accounts.
// A lot of this is copied from SFTPGo, but simplified for our use case.
      3 package main
      4 
      5 import (
      6 	"crypto/rand"
      7 	"crypto/rsa"
      8 	"crypto/x509"
      9 	"encoding/pem"
     10 	"fmt"
     11 	"io"
     12 	"io/ioutil"
     13 	"log"
     14 	"net"
     15 	"os"
     16 	"path"
     17 	"runtime/debug"
     18 	"time"
     19 
     20 	"github.com/pkg/sftp"
     21 	"golang.org/x/crypto/ssh"
     22 )
     23 
     24 type Connection struct {
     25 	User string
     26 }
     27 
     28 func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
     29 	// check user perms -- cant read others hidden files
     30 	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
     31 	fullpath := path.Join(userDir, cleanPath(request.Filepath))
     32 	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
     33 	if err != nil {
     34 		return nil, err
     35 	}
     36 	return f, nil
     37 }
     38 
     39 func (conn *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
     40 	// check user perms -- cant write others files
     41 	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
     42 	fullpath := path.Join(userDir, cleanPath(request.Filepath))
     43 	err := checkIfValidFile(conn.User, fullpath, []byte{})
     44 	if err != nil {
     45 		return nil, err
     46 	}
     47 	f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
     48 	if err != nil {
     49 		return nil, err
     50 	}
     51 	return f, nil
     52 }
     53 
     54 func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
     55 	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
     56 	fullpath := path.Join(userDir, cleanPath(request.Filepath))
     57 	switch request.Method {
     58 	case "List":
     59 		f, err := os.Open(fullpath)
     60 		if err != nil {
     61 			return nil, err
     62 		}
     63 		fileInfo, err := f.Readdir(-1)
     64 		if err != nil {
     65 			return nil, err
     66 		}
     67 		return listerat(fileInfo), nil
     68 	case "Stat":
     69 		stat, err := os.Stat(fullpath)
     70 		if err != nil {
     71 			return nil, err
     72 		}
     73 		return listerat([]os.FileInfo{stat}), nil
     74 	}
     75 	return nil, fmt.Errorf("Invalid command")
     76 }
     77 
     78 func (conn *Connection) Filecmd(request *sftp.Request) error {
     79 	// remove, rename, setstat? find out
     80 	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
     81 	fullpath := path.Join(userDir, cleanPath(request.Filepath))
     82 	targetPath := path.Join(userDir, cleanPath(request.Target))
     83 	var err error
     84 	switch request.Method {
     85 	case "Remove":
     86 		err = os.Remove(fullpath)
     87 	case "Mkdir":
     88 		err = os.Mkdir(fullpath, 0755)
     89 	case "Rename":
     90 		err := checkIfValidFile(conn.User, targetPath, []byte{})
     91 		if err != nil {
     92 			return err
     93 		}
     94 		err = os.Rename(fullpath, targetPath)
     95 	}
     96 	if err != nil {
     97 		return err
     98 	}
     99 	// Rename, Mkdir
    100 	return nil
    101 }
    102 
// TODO hide hidden folders
// Users have write perms on their own files, read perms on all.
    105 
    106 func buildHandlers(connection *Connection) sftp.Handlers {
    107 	return sftp.Handlers{
    108 		connection,
    109 		connection,
    110 		connection,
    111 		connection,
    112 	}
    113 }
    114 
    115 // Based on example server code from golang.org/x/crypto/ssh and server_standalone
    116 func runSFTPServer() {
    117 	if !c.EnableSFTP {
    118 		return
    119 	}
    120 	// An SSH server is represented by a ServerConfig, which holds
    121 	// certificate details and handles authentication of ServerConns.
    122 	config := &ssh.ServerConfig{
    123 		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
    124 			// Should use constant-time compare (or better, salt+hash) in
    125 			// a production setting.
    126 			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
    127 				return nil, fmt.Errorf("Invalid username")
    128 			}
    129 			_, _, err := checkLogin(c.User(), string(pass))
    130 			// TODO maybe give admin extra permissions?
    131 			if err != nil {
    132 				return nil, fmt.Errorf("password rejected for %q", c.User())
    133 			} else {
    134 				log.Printf("Login: %s\n", c.User())
    135 				return nil, nil
    136 			}
    137 		},
    138 	}
    139 
    140 	// TODO generate key automatically
    141 	if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
    142 		// path/to/whatever does not exist
    143 		log.Println("Host key not found, generating host key")
    144 		err := GenerateRSAKeys()
    145 		if err != nil {
    146 			log.Fatal(err)
    147 		}
    148 	}
    149 
    150 	privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
    151 	if err != nil {
    152 		log.Fatal("Failed to load private key", err)
    153 	}
    154 
    155 	private, err := ssh.ParsePrivateKey(privateBytes)
    156 	if err != nil {
    157 		log.Fatal("Failed to parse private key", err)
    158 	}
    159 
    160 	config.AddHostKey(private)
    161 
    162 	listener, err := net.Listen("tcp", "0.0.0.0:2024")
    163 	if err != nil {
    164 		log.Fatal("failed to listen for connection", err)
    165 	}
    166 
    167 	log.Printf("SFTP server listening on %v\n", listener.Addr())
    168 
    169 	for {
    170 		conn, err := listener.Accept()
    171 		if err != nil {
    172 			log.Fatal(err)
    173 		}
    174 		go acceptInboundConnection(conn, config)
    175 	}
    176 }
    177 
    178 func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
    179 	defer func() {
    180 		if r := recover(); r != nil {
    181 			log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
    182 		}
    183 	}()
    184 	ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
    185 	log.Println("Request from IP " + ipAddr)
    186 	limiter := getVisitor(ipAddr)
    187 	if limiter.Allow() == false {
    188 		conn.Close()
    189 		return
    190 	}
    191 	// Before beginning a handshake must be performed on the incoming net.Conn
    192 	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
    193 	conn.SetDeadline(time.Now().Add(2 * time.Minute))
    194 
    195 	// Before use, a handshake must be performed on the incoming net.Conn.
    196 	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
    197 	if err != nil {
    198 		log.Printf("failed to accept an incoming connection: %v", err)
    199 		return
    200 	}
    201 	log.Println("login detected:", sconn.User())
    202 	fmt.Fprintf(os.Stderr, "SSH server established\n")
    203 	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
    204 	conn.SetDeadline(time.Time{})
    205 
    206 	defer conn.Close()
    207 
    208 	// The incoming Request channel must be serviced.
    209 	go ssh.DiscardRequests(reqs)
    210 
    211 	// Service the incoming Channel channel.
    212 	channelCounter := int64(0)
    213 	for newChannel := range chans {
    214 		// Channels have a type, depending on the application level
    215 		// protocol intended. In the case of an SFTP session, this is "subsystem"
    216 		// with a payload string of "<length=4>sftp"
    217 		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
    218 		if newChannel.ChannelType() != "session" {
    219 			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
    220 			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
    221 			continue
    222 		}
    223 		channel, requests, err := newChannel.Accept()
    224 		if err != nil {
    225 			log.Println("could not accept channel.", err)
    226 			continue
    227 		}
    228 
    229 		channelCounter++
    230 		fmt.Fprintf(os.Stderr, "Channel accepted\n")
    231 
    232 		// Sessions have out-of-band requests such as "shell",
    233 		// "pty-req" and "env".  Here we handle only the
    234 		// "subsystem" request.
    235 		go func(in <-chan *ssh.Request) {
    236 			for req := range in {
    237 				fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
    238 				ok := false
    239 				switch req.Type {
    240 				case "subsystem":
    241 					fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
    242 					if string(req.Payload[4:]) == "sftp" {
    243 						ok = true
    244 					}
    245 				}
    246 				fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
    247 				req.Reply(ok, nil)
    248 			}
    249 		}(requests)
    250 		connection := Connection{sconn.User()}
    251 		root := buildHandlers(&connection)
    252 		server := sftp.NewRequestServer(channel, root)
    253 		if err := server.Serve(); err == io.EOF {
    254 			server.Close()
    255 			log.Println("sftp client exited session.")
    256 		} else if err != nil {
    257 			log.Println("sftp server completed with error:", err)
    258 			return
    259 		}
    260 	}
    261 }
    262 
    263 // GenerateRSAKeys generate rsa private and public keys and write the
    264 // private key to specified file and the public key to the specified
    265 // file adding the .pub suffix
    266 func GenerateRSAKeys() error {
    267 	key, err := rsa.GenerateKey(rand.Reader, 4096)
    268 	if err != nil {
    269 		return err
    270 	}
    271 
    272 	o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
    273 	if err != nil {
    274 		return err
    275 	}
    276 	defer o.Close()
    277 
    278 	priv := &pem.Block{
    279 		Type:  "RSA PRIVATE KEY",
    280 		Bytes: x509.MarshalPKCS1PrivateKey(key),
    281 	}
    282 
    283 	if err := pem.Encode(o, priv); err != nil {
    284 		return err
    285 	}
    286 
    287 	pub, err := ssh.NewPublicKey(&key.PublicKey)
    288 	if err != nil {
    289 		return err
    290 	}
    291 	return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
    292 }
    293 
    294 type listerat []os.FileInfo
    295 
    296 // Modeled after strings.Reader's ReadAt() implementation
    297 func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
    298 	var n int
    299 	if offset >= int64(len(f)) {
    300 		return 0, io.EOF
    301 	}
    302 	n = copy(ls, f[offset:])
    303 	if n < len(ls) {
    304 		return n, io.EOF
    305 	}
    306 	return n, nil
    307 }