# coding: utf-8

from ldap3 import Server, Connection, SIMPLE, SUBTREE

# Page size for the Ldap class's paged searches (passed to ldap3's
# paged_search as ``paged_size``).
LDAP_PAGE = 1000


class LdapError(Exception):
    """Error raised by the LDAP helpers in this module (result parsing or query failures)."""


class LdapRes:
    """A single LDAP search result: its DN plus its attribute mapping.

    Item access and iteration are delegated to the attribute mapping, so
    ``res['cn']`` reads an attribute and ``for name in res`` walks whatever
    iterating the mapping yields (the keys, for a dict).
    """

    def __init__(self, dn, attrib):
        # dn: distinguished name of the entry, stored as given.
        self.dn = dn
        # attr: attribute mapping, stored as given (not copied).
        self.attr = attrib

    def __getitem__(self, item):
        """Return attribute ``item``; lookup errors (e.g. KeyError) propagate from the mapping."""
        return self.attr[item]

    def __iter__(self):
        """Iterate over the attribute mapping."""
        return iter(self.attr)

    def __repr__(self):
        # Wrap dn in a 1-tuple: a bare ``% self.dn`` breaks (TypeError or
        # mis-formatting) when dn is itself a tuple.
        return '<LdapRes: dn: %s>' % (self.dn,)

    @classmethod
    def fromLdapQuery(cls, q):
        """Build an instance from one ldap3 search-response item.

        :param q: dict containing at least the ``'dn'`` and ``'attributes'`` keys.
        :returns: ``cls(q['dn'], q['attributes'])``.
        :raises LdapError: if ``q`` is not a dict or lacks either key.
        """
        if not isinstance(q, dict):
            raise LdapError('LdapRes: Parsing Error, not ldap response item')
        if 'dn' not in q or 'attributes' not in q:
            raise LdapError('LdapRes: Parsing Error, format mismatch')

        return cls(q['dn'], q['attributes'])


class Ldap:
    """Wrapper around an ldap3 server that runs paged subtree searches.

    Every query builds a fresh ldap3 Connection (SIMPLE authentication),
    opens and binds it, runs a paged subtree search, and yields an
    ``LdapRes`` for each ``searchResEntry`` in the response.
    """

    def __init__(self, host, user, passwd, timeout=60, queryTimeout=300, **kwa):
        """
        :param host: passed to ``ldap3.Server``.
        :param user: passed to the ldap3 Connection as ``user``.
        :param passwd: passed to the ldap3 Connection as ``password``.
        :param timeout: passed to ``ldap3.Server`` as ``connect_timeout``.
        :param queryTimeout: default search ``time_limit`` for queries that
            do not supply their own.
        :param kwa: an optional ``baseDN`` (default search base). Every other
            keyword is forwarded to ``ldap3.Server``.
        """
        # baseDN belongs to this wrapper, not to ldap3, so strip it before
        # forwarding the remaining keywords to Server.
        self._baseDN = kwa.pop('baseDN', None)
        ldapSrv = Server(host, connect_timeout=timeout, **kwa)
        # Keep a factory rather than a Connection: each query gets its own
        # connection object.
        self._conn = self._makeConnFabric(ldapSrv, authentication=SIMPLE,
                                          user=user, password=passwd,
                                          check_names=True, lazy=True,
                                          auto_referrals=False, raise_exceptions=True, auto_range=True
                                          )
        self.queryTimeout = queryTimeout

    def __call__(self, filter, attrib, queryTimeout=None, baseDN=None):
        """Start a search and return a generator of ``LdapRes``.

        Arguments are validated immediately, at call time. Network work
        happens lazily, when the returned generator is first iterated. The
        connection stays open until the generator finishes or is closed.

        :param filter: LDAP search filter string.
        :param attrib: attributes to request, passed to ldap3 as ``attributes``.
        :param queryTimeout: search ``time_limit``. Defaults to ``self.queryTimeout``.
        :param baseDN: search base. Defaults to the ``baseDN`` given to the constructor.
        :raises LdapError: immediately if no base DN is available; during
            iteration for any failure while connecting, searching or parsing.
        """
        if baseDN is None:
            if self._baseDN is None:
                raise LdapError('No base dn on query execution')
            baseDN = self._baseDN
        if queryTimeout is None:
            queryTimeout = self.queryTimeout
        return self._search(filter, attrib, queryTimeout, baseDN)

    def _search(self, filter, attrib, queryTimeout, baseDN):
        """Generator that connects, binds, runs the paged search and yields results."""
        try:
            conn = self._conn()
            with conn:
                conn.open()
                conn.bind()

                # generator=False: the full (paged) result set is fetched
                # before the first entry is yielded.
                res = conn.extend.standard.paged_search(baseDN,
                                                        filter, attributes=attrib, paged_size=LDAP_PAGE,
                                                        generator=False,
                                                        search_scope=SUBTREE, time_limit=queryTimeout
                                                        )

                for i in res:
                    # Skip non-entry items (e.g. referrals); keep real entries only.
                    if i['type'] == 'searchResEntry':
                        yield LdapRes.fromLdapQuery(i)

        except LdapError:
            # Already our error type (e.g. a parsing failure); do not double-wrap it.
            raise
        except Exception as e:
            raise LdapError("Error on get data (%s): %s" % (type(e), str(e)), *e.args[1:])

    def getList(self, *a, **kwa):
        """Run a query (same arguments as calling the instance) and return all results as a list."""
        return list(self(*a, **kwa))

    @staticmethod
    def _makeConnFabric(*a, **kwa):
        """Return a zero-argument callable that builds a new ``Connection(*a, **kwa)``."""
        def _func():
            return Connection(*a, **kwa)

        return _func
