1
0
mirror of https://github.com/kakwa/ldapcherry synced 2024-11-25 18:54:29 +01:00

fix issue related to python-ldap returning lists

Before, no particular treatment was done on the user attributes.
This caused some issues because python-ldap systematically returns
the attribute value as a list (even if it's mono-valuated).

Now we recover the attributes used in the group attr templates,
and we "normalize" the user attributes before using it in add_to_groups
and del_from_groups.

By normalize, we mean, transforming the list to it's unique value.
In case the attribute doesn't exist or is multi-valuated, it raises an
error.
This commit is contained in:
kakwa 2017-03-16 02:45:23 +01:00
parent 55ce2bec5e
commit 3fd6dcee82

View File

@ -11,6 +11,7 @@ import ldap.modlist as modlist
import ldap.filter import ldap.filter
import logging import logging
import ldapcherry.backend import ldapcherry.backend
from sets import Set
from ldapcherry.exceptions import UserDoesntExist, \ from ldapcherry.exceptions import UserDoesntExist, \
GroupDoesntExist, \ GroupDoesntExist, \
UserAlreadyExists UserAlreadyExists
@ -23,6 +24,21 @@ class CaFileDontExist(Exception):
self.cafile = cafile self.cafile = cafile
self.log = "CA file %(cafile)s does not exist" % {'cafile': cafile} self.log = "CA file %(cafile)s does not exist" % {'cafile': cafile}
class MissingGroupAttr(Exception):
def __init__(self, attr):
self.attr = attr
self.log = "User doesn't have %(attr)s in its attributes" \
", cannot use it to set group" % {'attr': attr}
class MultivaluedGroupAttr(Exception):
def __init__(self, attr):
self.attr = cafile
self.log = "User's attribute '%(attr)s' is multivalued" \
", cannot use it to set group" % {'attr': attr}
NO_ATTR = 0 NO_ATTR = 0
DISPLAYED_ATTRS = 1 DISPLAYED_ATTRS = 1
LISTED_ATTRS = 2 LISTED_ATTRS = 2
@ -58,10 +74,14 @@ class Backend(ldapcherry.backend.Backend):
for o in re.split('\W+', self.get_param('objectclasses')): for o in re.split('\W+', self.get_param('objectclasses')):
self.objectclasses.append(self._str(o)) self.objectclasses.append(self._str(o))
self.group_attrs = {} self.group_attrs = {}
self.group_attrs_keys = Set([])
for param in config: for param in config:
name, sep, group = param.partition('.') name, sep, group = param.partition('.')
if name == 'group_attr': if name == 'group_attr':
self.group_attrs[group] = self.get_param(param) self.group_attrs[group] = self.get_param(param)
self.group_attrs_keys |= Set(
self._extract_format_keys(self.get_param(param))
)
self.attrlist = [] self.attrlist = []
for a in attrslist: for a in attrslist:
@ -135,6 +155,39 @@ class Backend(ldapcherry.backend.Backend):
) )
raise raise
def _extract_format_keys(self, fmt_string):
"""Extract the keys of a format string
(the 'stuff' in '%(stuff)s'
"""
class AccessSaver:
def __init__(self):
self.keys = []
def __getitem__(self, key):
self.keys.append(key)
a = AccessSaver()
fmt_string % a
return a.keys
def _normalize_group_attrs(self, attrs):
"""Normalize the attributes used to set groups
If it's a list of one element, it just become this
element.
It raises an error if the attribute doesn't exist
or if it's multivaluated.
"""
for key in self.group_attrs_keys:
pass
if key not in attrs:
raise MissingGroupAttr(key)
if type(attrs[key]) is list and len(attrs[key]) == 1:
attrs[key] = attrs[key][0]
if type(attrs[key]) is list and len(attrs[key]) != 1:
raise MultivaluedGroupAttr(key)
def _connect(self): def _connect(self):
"""Initialize an ldap client""" """Initialize an ldap client"""
ldap_client = ldap.initialize(self.uri) ldap_client = ldap.initialize(self.uri)
@ -392,6 +445,7 @@ class Backend(ldapcherry.backend.Backend):
dn = tmp[0] dn = tmp[0]
attrs = tmp[1] attrs = tmp[1]
attrs['dn'] = dn attrs['dn'] = dn
self._normalize_group_attrs(attrs)
dn = self._str(tmp[0]) dn = self._str(tmp[0])
# add user to all groups # add user to all groups
for group in groups: for group in groups:
@ -447,6 +501,7 @@ class Backend(ldapcherry.backend.Backend):
dn = tmp[0] dn = tmp[0]
attrs = tmp[1] attrs = tmp[1]
attrs['dn'] = dn attrs['dn'] = dn
self._normalize_group_attrs(attrs)
dn = self._str(tmp[0]) dn = self._str(tmp[0])
for group in groups: for group in groups:
group = self._str(group) group = self._str(group)