fix usage of wrong variable
[cumulus.git] / python / cumulus / store / sftp.py
1 # vim: ai ts=4 sts=4 et sw=4
2 #needed for python 2.5
from __future__ import with_statement

import getpass
import os, os.path
import re
import sys

import paramiko.util
from paramiko import Transport, SFTPClient, RSAKey, DSSKey, SSHException
from paramiko.config import SSHConfig

from cumulus.store import Store, type_patterns, NotFoundError
13
14
class SSHHostConfig(dict):
    """Settings for one SSH host, looked up in an OpenSSH client config file.

    The mapping holds whatever paramiko's SSHConfig.lookup() returned for
    the host (keys such as 'hostname', 'user', 'port', 'identityfile').
    Reading 'port', 'user', 'hostname' or 'hostkeyalias' when the config
    does not set them falls back to a built-in default; any other missing
    key raises KeyError.  Note that only item access ([]) applies the
    defaults -- `in` and dict methods see just the configured values.
    """

    def __init__(self, hostname, user = None, filename = None):
        """Look up hostname in an ssh config file.

        hostname -- host name (or Host pattern match) to look up.
        user     -- if not None, overrides any User setting from the config.
        filename -- config file to read; defaults to ~/.ssh/config.  A file
                    given explicitly must exist (IOError otherwise), but the
                    default file is optional: when it is absent only the
                    built-in defaults apply.
        """
        if filename is None:
            filename = os.path.expanduser('~/.ssh/config')
            required = False
        else:
            required = True

        if required or os.path.exists(filename):
            ssh_config = SSHConfig()
            with open(filename) as config_file:
                ssh_config.parse(config_file)
            self.update(ssh_config.lookup(hostname))

        # Consulted by __getitem__ for keys the config file did not set.
        self.defaults = {'port': 22, 'user': getpass.getuser(),
                         'hostname': hostname, 'hostkeyalias': hostname}

        if user is not None:
            self['user'] = user

    def __getitem__(self, key):
        """Return the configured value for key, falling back to defaults.

        An unset 'hostkeyalias' follows a configured 'hostname', so the
        host key is looked up under the real host name behind a Host alias.
        Raises KeyError for keys that are neither configured nor defaulted.
        """
        if key in self:
            return dict.__getitem__(self, key)
        if key == 'hostkeyalias' and 'hostname' in self:
            return dict.__getitem__(self, 'hostname')
        return self.defaults[key]
40
41
class SFTPStore(Store):
    """Implements the sftp:// storage backend.

    Configuration comes from openssh/sftp style URLs (the host part may be
    written as user@host) and from ~/.ssh/config.  The server's host key
    must already be listed in ~/.ssh/known_hosts and is verified when
    connecting.

    Password authentication and password-protected authentication keys
    are not supported.
    """

    def __init__(self, url, **kw):
        # NOTE(review): self.netloc and self.path are read but never set in
        # this class; presumably the Store base class parses them from url.
        if '@' in self.netloc:
            # Only the last '@' separates the user name from the host.
            user, self.netloc = self.netloc.rsplit('@', 1)
        else:
            user = None

        self.config = SSHHostConfig(self.netloc, user)

        known_hosts = os.path.expanduser('~/.ssh/known_hosts')
        host_keys = paramiko.util.load_host_keys(known_hosts)
        alias = self.config['hostkeyalias']
        try:
            # Any one stored key is enough: when Transport.connect() is
            # given a host key, paramiko negotiates that key's type.
            self.hostkey = list(host_keys[alias].values())[0]
        except KeyError:
            raise KeyError('no host key for %r in %s' % (alias, known_hosts))

        self.auth_key = self.__load_auth_key()
        self.__connect()

    def __load_auth_key(self):
        """Return the private key used to log in.

        Uses the IdentityFile from the ssh config when one is set, else
        ~/.ssh/id_rsa, else ~/.ssh/id_dsa.  Raises SSHException when no key
        file is configured and neither default file exists.
        """
        if 'identityfile' in self.config:
            identity = self.config['identityfile']
            # Some paramiko versions return IdentityFile as a list of paths;
            # only the first entry is used here.
            if isinstance(identity, (list, tuple)):
                identity = identity[0]
            return self.__load_key_file(os.path.expanduser(identity))

        rsa_file = os.path.expanduser('~/.ssh/id_rsa')
        if os.path.exists(rsa_file):
            return RSAKey.from_private_key_file(rsa_file)
        dsa_file = os.path.expanduser('~/.ssh/id_dsa')
        if os.path.exists(dsa_file):
            return DSSKey.from_private_key_file(dsa_file)
        raise SSHException('no private key found for %s: set IdentityFile '
                           'in ~/.ssh/config or create ~/.ssh/id_rsa'
                           % self.config['hostname'])

    @staticmethod
    def __load_key_file(key_file):
        """Load an unencrypted RSA or DSA private key from key_file."""
        # The first positional argument of the RSAKey/DSSKey constructors is
        # a key message, not a file name, so the file factory is required.
        try:
            return RSAKey.from_private_key_file(key_file)
        except SSHException:
            # Not readable as an RSA key; try DSA.  If that fails as well,
            # its SSHException propagates to the caller.
            return DSSKey.from_private_key_file(key_file)

    def __connect(self):
        """Open the SSH transport, authenticate and start the SFTP session."""
        # A Port value read from the config file may be a string.
        transport = Transport((self.config['hostname'],
                               int(self.config['port'])))
        try:
            # Passing hostkey makes paramiko check the server's identity
            # against the known_hosts entry loaded in __init__.
            transport.connect(hostkey = self.hostkey,
                              username = self.config['user'],
                              pkey = self.auth_key)
            client = SFTPClient.from_transport(transport)
            client.chdir(self.path)
        except BaseException:
            # The caller never gets an object it could close(), so shut the
            # transport down here rather than leaving it open.
            transport.close()
            raise
        self.t = transport
        self.client = client

    def __build_fn(self, name):
        """Return the remote path of the object called name."""
        return "%s/%s" % (self.path,  name)

    def list(self, type):
        """Return the names in the store directory matching type's pattern."""
        pattern = type_patterns[type]
        return [name for name in self.client.listdir(self.path)
                if pattern.match(name)]

    def get(self, type, name):
        """Open the object called name for reading (binary)."""
        return self.client.open(self.__build_fn(name), mode = 'rb')

    def put(self, type, name, fp):
        """Copy the contents of file object fp to the object called name."""
        remote_file = self.client.open(self.__build_fn(name), mode = 'wb')
        try:
            while True:
                buf = fp.read(4096)
                if not buf:
                    break
                remote_file.write(buf)
        finally:
            # Release the remote handle even if reading or writing fails.
            remote_file.close()

    def delete(self, type, name):
        """Remove the object called name."""
        self.client.remove(self.__build_fn(name))

    def stat(self, type, name):
        """Return {'size': ...} for the object; NotFoundError on IOError."""
        try:
            attrs = self.client.stat(self.__build_fn(name))
            return {'size': attrs.st_size}
        except IOError:
            raise NotFoundError

    def close(self):
        """Close the SFTP session and the underlying SSH transport.

        The connection has to be closed explicitly, otherwise it keeps
        the process running indefinitely."""
        self.client.close()
        self.t.close()
125
# Rebind the module-level name Store (imported from cumulus.store at the top
# of this file) to the SFTP backend class.  Presumably the cumulus.store
# loader looks up `<scheme module>.Store` to find the backend for sftp://
# URLs -- TODO confirm against cumulus/store/__init__.py.
Store = SFTPStore