code cleanup & regression fix
[cumulus.git] / python / cumulus / store / sftp.py
1 # vim: ai ts=4 sts=4 et sw=4
2 #needed for python 2.5
3 from __future__ import with_statement
4
5 from paramiko import Transport, SFTPClient, RSAKey, DSSKey
6 from paramiko.config import SSHConfig
7 import paramiko.util
8 from cumulus.store import Store, type_patterns, NotFoundError
9 import os, os.path
10 import getpass
11 import re
12 import sys
13
14
class SSHHostConfig(dict):
    """Per-host ssh settings read from an OpenSSH-style config file.

    The mapping is filled with the result of looking ``hostname`` up in the
    parsed config file.  Keys not present there fall back to built-in
    defaults: port 22, the local login name, and ``hostname`` itself for
    both 'hostname' and 'hostkeyalias'.  An unset 'hostkeyalias' follows a
    configured 'hostname' before falling back to the default.

    hostname -- host to look up in the config file
    user     -- if given, overrides any 'user' setting from the config
    filename -- config file to read, ~/.ssh/config when omitted; a file
                that does not exist is treated as an empty config
    """

    def __init__(self, hostname, user = None, filename = None):
        if filename is None:
            filename = os.path.expanduser('~/.ssh/config')

        # Having no ssh config file at all is normal; only parse one that
        # exists, otherwise rely purely on the defaults below.
        if os.path.exists(filename):
            ssh_config = SSHConfig()
            with open(filename) as config_file:
                ssh_config.parse(config_file)
            self.update(ssh_config.lookup(hostname))

        # Consulted by __getitem__ for keys the config did not provide.
        self.defaults = {'port': 22,
                         'user': getpass.getuser(),
                         'hostname': hostname,
                         'hostkeyalias': hostname}

        if user is not None:
            self['user'] = user

    def __getitem__(self, key):
        """Return the configured value for key, else its default.

        Raises KeyError for keys that are neither configured nor defaulted.
        """
        if key in self:
            return dict.__getitem__(self, key)
        if key == 'hostkeyalias' and 'hostname' in self:
            # known_hosts entries are keyed by the real host name
            return dict.__getitem__(self, 'hostname')
        return self.defaults[key]
40
41
class SFTPStore(Store):
    """Storage backend for sftp:// URLs.

    Connection settings come from the URL (an optional ``user@`` prefix and
    the host) combined with the matching section of ~/.ssh/config.  The
    server's host key must already be present in ~/.ssh/known_hosts and is
    checked when connecting.

    Only public-key authentication with RSA or DSA key files is supported;
    password authentication and passphrase-protected keys are not.

    close() must be called explicitly when the store is no longer needed.
    """

    def __init__(self, url, **kw):
        """Resolve connection settings, load keys and connect.

        ``url`` and ``kw`` are not read here: this relies on ``self.netloc``
        and ``self.path`` having been set already (presumably by the Store
        base class from ``url`` -- not visible in this file).
        """
        user = None
        if '@' in self.netloc:
            # userinfo ends at the last '@'; the host follows it
            user, self.netloc = self.netloc.rsplit('@', 1)

        self.config = SSHHostConfig(self.netloc, user)
        self.hostkey = self.__load_host_key(self.config['hostkeyalias'])
        self.auth_key = self.__load_auth_key()
        self.__connect()

    def __load_host_key(self, alias):
        """Return the ~/.ssh/known_hosts key recorded for ``alias``.

        Raises KeyError if known_hosts has no entry for ``alias``.
        """
        known_hosts = os.path.expanduser('~/.ssh/known_hosts')
        host_keys = paramiko.util.load_host_keys(known_hosts)
        try:
            keys_by_type = host_keys[alias]
        except KeyError:
            raise KeyError('no host key for %r in %s' % (alias, known_hosts))
        # the first key found for the host is the one verified on connect
        return list(keys_by_type.values())[0]

    def __load_auth_key(self):
        """Return the private key used for public-key authentication.

        Uses the config's 'identityfile' when set, else ~/.ssh/id_rsa, else
        ~/.ssh/id_dsa.  Raises paramiko.SSHException if no key file exists.
        """
        if 'identityfile' in self.config:
            key_file = self.config['identityfile']
            # NOTE: newer paramiko releases may return IdentityFile as a
            # list of paths; use the first one in that case.
            if isinstance(key_file, list):
                key_file = key_file[0]
            return self.__read_key(os.path.expanduser(key_file))

        rsa_file = os.path.expanduser('~/.ssh/id_rsa')
        if os.path.exists(rsa_file):
            return RSAKey(filename = rsa_file)
        dsa_file = os.path.expanduser('~/.ssh/id_dsa')
        if os.path.exists(dsa_file):
            return DSSKey(filename = dsa_file)
        raise paramiko.SSHException(
            'no private key found for sftp authentication '
            '(tried IdentityFile, ~/.ssh/id_rsa and ~/.ssh/id_dsa)')

    @staticmethod
    def __read_key(key_file):
        """Load key_file as an RSA key, falling back to DSA.

        paramiko uses a separate class per key type, so the type is found by
        trying RSA first; if DSA loading fails too, that error propagates.
        """
        try:
            return RSAKey(filename = key_file)
        except paramiko.SSHException:
            pass
        return DSSKey(filename = key_file)

    def __connect(self):
        """Open the ssh transport, authenticate and start an sftp session
        in self.path.

        The transport is closed again if any step fails, so a failed
        connect does not leave it running.
        """
        self.t = Transport((self.config['hostname'],
                            int(self.config['port'])))
        try:
            # passing hostkey makes paramiko verify the server against the
            # known_hosts entry instead of accepting any key
            self.t.connect(hostkey = self.hostkey,
                           username = self.config['user'],
                           pkey = self.auth_key)
            self.client = SFTPClient.from_transport(self.t)
            self.client.chdir(self.path)
        except BaseException:
            self.t.close()
            raise

    def __build_fn(self, name):
        """Return the remote path of object ``name`` under self.path."""
        return "%s/%s" % (self.path, name)

    def list(self, type):
        """Return the names in self.path matching type_patterns[type]."""
        pattern = type_patterns[type]
        return [entry for entry in self.client.listdir(self.path)
                if pattern.match(entry)]

    def get(self, type, name):
        """Open object ``name`` for binary reading and return the file."""
        return self.client.open(filename = self.__build_fn(name), mode = 'rb')

    def put(self, type, name, fp):
        """Copy everything readable from file-like ``fp`` into ``name``."""
        remote_file = self.client.open(filename = self.__build_fn(name),
                                       mode = 'wb')
        try:
            buf = fp.read(4096)
            while buf:
                remote_file.write(buf)
                buf = fp.read(4096)
        finally:
            # close even on error so the remote handle is not leaked
            remote_file.close()

    def delete(self, type, name):
        """Remove object ``name``."""
        # SFTPClient.remove() names its parameter 'path': pass positionally
        self.client.remove(self.__build_fn(name))

    def stat(self, type, name):
        """Return {'size': <bytes>} for object ``name``.

        Raises NotFoundError if the object does not exist; other I/O errors
        propagate unchanged.
        """
        import errno
        try:
            attrs = self.client.stat(self.__build_fn(name))
        except IOError:
            if sys.exc_info()[1].errno == errno.ENOENT:
                raise NotFoundError((type, name))
            raise
        return {'size': attrs.st_size}

    def close(self):
        """Close the sftp session and the ssh transport.

        This has to be called explicitly; otherwise the open connection
        keeps the process running indefinitely."""
        self.client.close()
        self.t.close()

# Exported under the name `Store`, presumably so the generic store factory
# can locate this backend by module -- confirm in cumulus.store.
Store = SFTPStore