puppet/ldap/files/dynldap.py

382 lines
11 KiB
Python

from copy import copy
import ldap
import ldap.modlist
import ldap.sasl
import os
import re
from subprocess import Popen, PIPE
from types import ListType
# Samba's secrets database; get_samba_secrets_param() reads it via tdbdump.
SAMBASECRETS = '/var/lib/samba/private/secrets.tdb'

# OpenLDAP client configuration files, in the order parse_openldap_config()
# tries them: the per-user ~/.ldaprc first, then system-wide ldap.conf files.
OPENLDAPCONF = (
    os.path.join(os.getenv('HOME', ''), '.ldaprc'),
    '/usr/local/etc/openldap/ldap.conf',
    '/etc/openldap/ldap.conf',
)
def auth_callback(conn, username=None, password=None):
    """Default credential callback used by Connection.bind().

    Fills in whichever of *username* / *password* is missing:

    * username: the DN of the single posixAccount entry whose uid equals the
      local login name.  If there is no such unique entry, root falls back to
      'cn=manager,<conn.basedn>' and anyone else is prompted for a bind DN.
    * password: read from the terminal without echo.

    Returns a ``(username, password)`` tuple.
    """
    from os import getlogin
    from getpass import getpass, getuser
    import ldap.filter

    if username is None:
        # getlogin() needs a controlling terminal and raises OSError without
        # one (cron jobs, daemons); fall back to the environment / passwd db.
        try:
            login = getlogin()
        except OSError:
            login = getuser()
        try:
            username = conn.searchone(
                '(&(objectClass=posixAccount)(uid=%s))'
                % ldap.filter.escape_filter_chars(login),
                attrs=[]).dn
        except (IndexError, ldap.NO_SUCH_OBJECT):
            # Connection.searchone() signals "no match" and "several
            # matches" with NO_SUCH_OBJECT, not IndexError.
            if login == 'root':
                username = 'cn=manager,%s' % conn.basedn
            else:
                username = raw_input('Enter bind DN: ')
    if password is None:
        password = getpass('Enter LDAP password: ')
    return (username, password)
def auth_samba_callback(conn, username=None, password=None):
    """Credential callback that reuses Samba's LDAP admin credentials.

    The bind DN is smb.conf's 'ldap admin dn'; the password is the matching
    SECRETS/LDAP_BIND_PW/<dn> record from Samba's secrets database.  The
    *conn*, *username* and *password* arguments exist only so the signature
    matches auth_callback; they are ignored.

    Returns a ``(binddn, password)`` tuple.

    Raises:
        ValueError: Samba's 'passdb backend' is not ldapsam.
        KeyError: the parameter or secret could not be found.
    """
    backend = get_samba_conf_param('passdb backend')
    # Accept 'ldapsam', 'ldapsam:ldap://host' and the quoted multi-URL form
    # 'ldapsam:"ldap://a ldap://b"'; testparm may also print nothing (None).
    if backend is None or re.match(r'^ldapsam(:|$)', backend) is None:
        raise ValueError('Samba not configured for LDAP backend')
    binddn = get_samba_conf_param('ldap admin dn')
    authpw = get_samba_secrets_param('SECRETS/LDAP_BIND_PW/%s' % binddn)
    return (binddn, authpw)
def get_samba_conf_param(key, section = 'global'):
    """Return the value of smb.conf parameter *key* in *section* via testparm.

    Only the first line testparm prints on stdout is used.  If that whole line
    is wrapped in double quotes the quotes are removed; otherwise surrounding
    whitespace is stripped.

    Returns None if testparm prints nothing on stdout and reports no
    unknown-parameter error.

    Raises:
        KeyError: testparm reported the parameter as unknown for the section.
    """
    cmd = Popen(['testparm', '-s', '-v', '--section-name=%s' % section,
                 '-d', '0', '--parameter-name=%s' % key],
                stdout=PIPE, stderr=PIPE, universal_newlines=True)
    # communicate() drains both pipes while waiting for the process; calling
    # wait() before reading can deadlock once a pipe buffer fills up.
    (out, err) = cmd.communicate()
    lines = out.splitlines()
    if lines:
        first = lines[0]
        m = re.match(r'^"(.*)"$', first)
        if m is not None:
            return m.group(1)
        return first.strip()
    for line in err.splitlines():
        if re.search(r'Parameter .* unknown for section [^ ]', line):
            raise KeyError(line.strip())
    return None
def get_samba_secrets_param(key):
    """Return the value stored under *key* in Samba's secrets database.

    The database (SAMBASECRETS) is read through ``tdbdump``, whose records
    look like::

        {
        key(N) = "..."
        data(M) = "..."
        }

    Bytes that tdbdump does not print literally appear as ``\\XX`` (two hex
    digits), e.g. the ``\\00`` string terminator; these escapes are decoded.
    One trailing NUL is removed from the data and, for the comparison, from
    the key.

    Raises:
        KeyError: no record with that key exists.
    """
    def _unescape(text):
        # Turn tdbdump's \XX hex escapes back into the original characters.
        return re.sub(r'\\([0-9A-Fa-f]{2})',
                      lambda m: chr(int(m.group(1), 16)), text)

    def _strip_nul(text):
        return text[:-1] if text.endswith('\x00') else text

    cmd = Popen(['tdbdump', SAMBASECRETS], stdout=PIPE,
                universal_newlines=True)
    # communicate() reads all output before waiting; wait() first can deadlock
    # on a large database.
    out = cmd.communicate()[0]
    k = None
    v = None
    for line in out.splitlines():
        line = line.strip()
        if line == '{':
            k = None
            v = None
            continue
        if line == '}':
            if k == key:
                return v
            continue
        m = re.match(r'^(key|data)\(\d+\) = "(.*)"$', line)
        if m is None:
            # Not part of a key/data record (e.g. a blank line); ignore it.
            continue
        value = _strip_nul(_unescape(m.group(2)))
        if m.group(1) == 'key':
            k = value
        else:
            v = value
    raise KeyError(key)
def connect(uris = None, basedn = None):
    """Open a Connection to *uris* (one URI or a list of URIs).

    When *uris* is None the server list, and the base DN if *basedn* is not
    given, are read from the OpenLDAP client configuration.  An explicitly
    passed *basedn* always wins.  If no base DN is known at all, Connection
    discovers one from the server.
    """
    if uris is None:
        (uris, conf_basedn) = parse_openldap_config()
        if basedn is None:
            basedn = conf_basedn
    return Connection(uris, basedn)
def parse_openldap_config(paths=None):
    """Read the LDAP server URIs and search base from OpenLDAP client config.

    The files in *paths* (default: OPENLDAPCONF) are tried in order; missing
    or unreadable ones are skipped.  The first file that sets a non-empty URI
    wins, and its BASE (or None if it has none) is returned with it.  Option
    names are matched case-insensitively, as ldap.conf(5) allows.

    Returns:
        ``(uris, basedn)``: the list of URI strings and the BASE string or
        None.

    Raises:
        ldap.SERVER_DOWN: none of the files defines a URI.
    """
    if paths is None:
        paths = OPENLDAPCONF
    for name in paths:
        uris = None
        basedn = None
        try:
            f = open(name, 'r')
        except IOError:
            continue
        with f:
            for line in f:
                m = re.match(r'^\s*(URI|BASE)\s+(.*)', line, re.IGNORECASE)
                if m is None:
                    continue
                if m.group(1).upper() == 'URI':
                    # URI may list several space-separated servers.
                    uris = m.group(2).split()
                else:
                    basedn = m.group(2).strip()
        if uris:
            return (uris, basedn)
    raise ldap.SERVER_DOWN(
        {'desc': 'Could not determine LDAP server address'})
class LdapEntry:
    """Editable in-memory copy of one LDAP entry.

    Each LDAP attribute is stored as a list of strings in the instance
    __dict__.  It is reachable as ``entry.attr`` (the raw list) and as
    ``entry['attr']`` (a de-duplicated, sorted copy, or None if unset).  Names
    starting with '_' are reserved for internal state.  The values as loaded
    (or as of the last save()) are kept in self._orig, so changes() can build
    a modlist and save() can write only the difference.
    """

    def __init__(self, conn, entry, new=False):
        """Wrap *entry*, a ``(dn, {attr: values})`` pair, belonging to *conn*.

        Pass new=True for an entry not yet on the server, so that the first
        save() adds it instead of modifying it.
        """
        for (attr, values) in entry[1].items():
            self._set(attr, values)
        if new:
            self._orig = {}
        else:
            self._orig = copy(entry[1])
        # Tracked explicitly: an existing entry fetched with attrs=[] also has
        # an empty _orig, and must still be modified rather than re-added.
        self._new = bool(new)
        self._dn = entry[0]
        self._conn = conn

    def _del(self, key):
        """Remove LDAP attribute *key*; a no-op if it is not set."""
        if key in self._attrs():
            del self.__dict__[key]

    __delitem__ = _del

    def _get(self, key):
        """Return the distinct values of *key* in sorted order, or None."""
        if key in self._attrs():
            return sorted(set(self.__dict__[key]))
        return None

    __getitem__ = _get

    def __iter__(self):
        return iter(self._attrs())

    def __str__(self):
        """Render the entry as 'name: value' lines, starting with the DN."""
        lines = ['dn: %s\n' % self.dn]
        for (key, values) in self.items():
            for value in values:
                lines.append('%s: %s\n' % (key, value))
        return ''.join(lines)

    def _set(self, key, value):
        """Set *key*: '_' names are stored verbatim, None deletes the
        attribute, a list or tuple supplies several values, and anything else
        is a single value.  Values are stored as strings."""
        if key[0] == '_':
            self.__dict__[key] = value
        elif value is None:
            self._del(key)
        else:
            if not isinstance(value, (list, tuple)):
                value = [value]
            self.__dict__[key] = [str(x) for x in value]

    __setattr__ = _set
    __setitem__ = __setattr__

    @property
    def dn(self):
        return self._dn

    @property
    def modified(self):
        """True if there are unsaved changes."""
        return len(self.changes()) > 0

    def _attrs(self):
        """Return the names of the LDAP attributes currently set."""
        return [k for (k, _) in self.items()]

    attrs = _attrs

    def changes(self):
        """Return the python-ldap modlist turning _orig into current state."""
        new = {}
        for k in self._attrs():
            new[k] = self._get(k)
        return ldap.modlist.modifyModlist(self._orig, new)

    def items(self):
        """Return ``(attr, values)`` pairs for all LDAP attributes."""
        return [(k, v) for (k, v) in self.__dict__.items() if k[0] != '_']

    def revert(self):
        """Discard unsaved changes, restoring the attributes in _orig."""
        for k in self._attrs():
            del self.__dict__[k]
        for (k, v) in self._orig.items():
            self._set(k, v)

    def save(self):
        """Write pending changes: add_s() for a new entry, else modify_s()."""
        changes = self.changes()
        if len(changes) > 0:
            if self._new:
                self._conn._conn.add_s(self.dn, self.items())
                self._new = False
            else:
                self._conn._conn.modify_s(self.dn, changes)
            # Snapshot copies of the lists: keeping references would make
            # later in-place edits (entry.attr.append(...)) invisible to
            # changes().
            self._orig = dict((k, list(v)) for (k, v) in self.items())
class Connection:
    """Stateful wrapper around a python-ldap connection object.

    Remembers the server it connected to, the URI list that server came from,
    and the credentials of the last successful bind, so reconnect() can
    restore the session.  Searches return lazily-fetched LdapResult sequences
    of LdapEntry objects.
    """

    def __init__(self, uri, basedn = None):
        """Connect to the first reachable server in *uri* (a URI or a list).

        If *basedn* is None, the first namingContexts value read from the
        empty DN (the root DSE) is used.
        """
        self.auth = (None, None)
        self.reconnect(uri)
        if basedn is None:
            entry = self.get('', attrs=['namingContexts'])
            basedn = entry.namingContexts[0]
        self.basedn = basedn
        self.set_option(ldap.OPT_REFERRALS, False)

    def __del__(self):
        # Best effort only: the connection may never have been established
        # (reconnect() raised in __init__) or may already be unbound, and an
        # exception cannot usefully escape __del__ anyway.
        try:
            self.unbind()
        except Exception:
            pass

    def reconnect(self, uris = None):
        """(Re)connect to the first reachable server and restore the bind.

        *uris* may be a single URI or an iterable of URIs.  When omitted, the
        server last connected to is retried first, followed by the other
        servers from the list it was chosen from.

        Raises:
            ldap.SERVER_DOWN: none of the servers could be reached.
        """
        if uris is None:
            uris = [self.uri] + [u for u in getattr(self, '_uris', [])
                                 if u != self.uri]
        elif isinstance(uris, str) or not hasattr(uris, '__iter__'):
            # A lone URI string (str is iterable on Python 3; don't split it).
            uris = [uris]
        else:
            uris = list(uris)
        for uri in uris:
            try:
                self._conn = ldap.initialize(uri)
                # initialize() is lazy; whoami forces real contact so a dead
                # server is detected here.
                self.whoami()
                if self.auth[0] is not None and self.auth[1] is not None:
                    self.bind(self.auth[0], self.auth[1])
            except ldap.SERVER_DOWN:
                continue
            self.uri = uri
            self._uris = uris
            return
        raise ldap.SERVER_DOWN({'desc': "Can't contact LDAP server"})

    def bind(self, username=None, password=None, callback=auth_callback):
        """Simple-bind as *username*, asking *callback* for missing parts.

        Raises:
            ldap.INVALID_CREDENTIALS: the password is empty or missing.
        """
        if username is None or password is None:
            (username, password) = callback(self, username, password)
        # NOTE(review): presumably guards against LDAP "unauthenticated"
        # simple binds, which succeed with an empty password.
        if password is None or password == '':
            raise ldap.INVALID_CREDENTIALS(
                {'desc': 'Empty passwords not allowed'})
        self._conn.bind_s(username, password)
        self.auth = (username, password)

    def delete(self, dn):
        return self._conn.delete_s(dn)

    def get(self, dn, attrs = None):
        """Return the single entry at *dn* (base-scope search)."""
        return self.search(scope=ldap.SCOPE_BASE, basedn=dn, attrs=attrs)[0]

    def get_option(self, option):
        return self._conn.get_option(option)

    def new(self, dn, entry = None, **kwargs):
        """Create an unsaved entry object; save() on it adds it to the server.

        *entry* is an optional ``{attr: values}`` dict; kwargs may carry
        'entryclass' (default LdapEntry).
        """
        if entry is None:
            entry = {}
        entryclass = kwargs.get('entryclass', LdapEntry)
        return entryclass(self, [dn, entry], True)

    def passwd(self, newpw, **kwargs):
        """Change a password: kwargs 'user' (default: bound identity) and
        'oldpw' (default None)."""
        user = kwargs['user'] if 'user' in kwargs else self.whoami()
        self._conn.passwd_s(user, kwargs.get('oldpw'), newpw)

    def ping(self, reconnect=True):
        """Return True if the server answers.

        If it does not and *reconnect* is true, try reconnect() and report
        whether that succeeded.
        """
        try:
            self.whoami()
            return True
        except ldap.SERVER_DOWN:
            if not reconnect:
                return False
        try:
            self.reconnect()
            return True
        except ldap.SERVER_DOWN:
            return False

    def search(self, filterstr = '(objectClass=*)', **kwargs):
        """Start an asynchronous search and return a lazy LdapResult.

        kwargs: basedn (default self.basedn), scope (default SCOPE_SUBTREE),
        attrs (default None = all), timeout (default: the connection's),
        entryclass (default LdapEntry).
        """
        # Defaults are filled in lazily: __init__ calls get() before
        # self.basedn exists.
        if 'basedn' not in kwargs:
            kwargs['basedn'] = self.basedn
        if 'scope' not in kwargs:
            kwargs['scope'] = ldap.SCOPE_SUBTREE
        if 'attrs' not in kwargs:
            kwargs['attrs'] = None
        if 'timeout' not in kwargs:
            kwargs['timeout'] = self._conn.timeout
        if 'entryclass' not in kwargs:
            kwargs['entryclass'] = LdapEntry
        msgid = self._conn.search_ext(kwargs['basedn'], kwargs['scope'],
                                      filterstr, kwargs['attrs'], 0, None,
                                      None, kwargs['timeout'], 1000000)
        return LdapResult(self, msgid, kwargs['entryclass'])

    def searchone(self, filterstr, **kwargs):
        """Return the only entry matching *filterstr*.

        Raises:
            ldap.NO_SUCH_OBJECT: zero or several entries matched.
        """
        result = self.search(filterstr, **kwargs)
        if len(result) == 0:
            raise ldap.NO_SUCH_OBJECT({'desc': 'No such object'})
        elif len(result) == 1:
            return result[0]
        else:
            raise ldap.NO_SUCH_OBJECT({'desc': 'Multiple objects matched'})

    def set_option(self, option, value):
        self._conn.set_option(option, value)

    def unbind(self):
        self._conn.unbind_s()
        self.auth = (None, None)

    def whoami(self):
        """Return the bound identity without its 'dn:'/'u:' prefix, or None.

        Split only on the first ':' so identities containing ':' survive.
        """
        try:
            return self._conn.whoami_s().split(':', 1)[1]
        except IndexError:
            return None
class LdapResult:
    """Lazily-evaluated sequence of the entries of an asynchronous search.

    Entries are pulled from the server one at a time, as they are indexed or
    iterated, and cached in self._entries.  len(), keys() and negative or
    slice indexing read the whole result.
    """

    class Iterator:
        """Forward iterator over an LdapResult, driven by its __getitem__."""

        def __init__(self, result):
            self._result = result
            self._position = -1

        def __iter__(self):
            return self

        def next(self):
            try:
                r = self._result[self._position + 1]
            except IndexError:
                raise StopIteration
            self._position += 1
            return r

        # Python 3 spelling of the iterator protocol method.
        __next__ = next

    def __init__(self, conn, msgid, entryclass = LdapEntry):
        self._conn = conn
        self._msgid = msgid
        self._entries = []
        self._entryclass = entryclass

    def __del__(self):
        # Abandon the search if it is still running.  After completion
        # _msgid is None, so abandon() raises TypeError, which is ignored.
        try:
            self._conn._conn.abandon(self._msgid)
        except (ldap.LDAPError, TypeError):
            pass

    def __getitem__(self, idx):
        if isinstance(idx, slice) or idx < 0:
            # Counting from the end needs the complete result.
            len(self)
            return self._entries[idx]
        while len(self._entries) <= idx:
            try:
                self._next()
            except StopIteration:
                raise IndexError('list index out of range')
        return self._entries[idx]

    def __iter__(self):
        return self.Iterator(self)

    def __len__(self):
        while True:
            try:
                self._next()
            except StopIteration:
                break
        return len(self._entries)

    def __str__(self):
        return str(self.keys())

    def _next(self):
        """Fetch, cache and return the next entry.

        Raises StopIteration when the search has finished.
        """
        while True:
            if self._msgid is None:
                raise StopIteration
            try:
                (rtype, data) = self._conn._conn.result(self._msgid, all=0,
                                                        timeout=-1)
            except IndexError:
                raise StopIteration
            if rtype == ldap.RES_SEARCH_RESULT:
                self._msgid = None
                raise StopIteration
            if rtype == ldap.RES_SEARCH_REFERENCE:
                # Continuation references carry URLs, not an entry dict, and
                # cannot be wrapped in an entry object; skip them.
                continue
            break
        self._entries.append(self._entryclass(self._conn, data[0]))
        return self._entries[-1]

    def keys(self):
        """Return the DNs of all entries in the result."""
        return [entry.dn for entry in self]