# coding: utf-8
"""\
Database schema migration tool.

DOC: https://ws.a0fs.net/w/ncc.ws/devel/python/modules/db_migrator/
"""

from os.path import exists, join as p_join, isdir
from os import listdir


class MigrateError(Exception):
    """Raised when the migration environment or a patch file name is invalid."""


class MigrateManager(object):
    """Apply a base schema and numbered SQL patches through a DB connection.

    The migration environment is a directory laid out as::

        <migrate_env>/schema.sql       base schema, executed by init_db()
        <migrate_env>/patch/<N>*.sql   patches; the text before the first
                                       dot of the file name is the integer
                                       version of the patch

    The version reached so far is read from and written to the ``version``
    column of ``control_table``.
    """

    def __init__(self, control_table: str, migrate_env: str):
        """Validate the migration environment and remember its paths.

        :param control_table: name of the table holding the ``version`` column.
        :param migrate_env: directory containing ``schema.sql`` and ``patch/``.
        :raises MigrateError: if the directory, the schema file or the patch
            directory is missing.
        """
        self.control_table = control_table

        if not exists(migrate_env):
            raise MigrateError('Migrate enviroment not found')

        self.schema = p_join(migrate_env, 'schema.sql')
        if not exists(self.schema):
            raise MigrateError('Schema file not found: %s' % self.schema)

        self.patch_dir = p_join(migrate_env, 'patch')
        if not isdir(self.patch_dir):
            raise MigrateError('Patch dir not found or not directory: %s' % self.patch_dir)

    def get_patch_files(self, ver: int):
        """Yield ``(version, path)`` for every patch newer than ``ver``.

        Only entries of the patch directory whose name ends with ``.sql``
        (case-insensitive) are considered. Pairs are yielded in ascending
        version order. This is a generator, so the directory is scanned and
        errors are raised when iteration starts, not when it is called.

        :param ver: current schema version; patches with a version less than
            or equal to it are skipped.
        :raises MigrateError: if a patch file name does not start with an
            integer, or two files resolve to the same version.
        """
        res = {}
        for f in listdir(self.patch_dir):
            if not f.lower().endswith('.sql'):
                continue

            # str.split() always returns at least one element, so taking
            # item 0 cannot fail; only the int() conversion can.
            ver_text = f.strip().split('.')[0]

            try:
                _ver = int(ver_text)

            except ValueError as e:
                raise MigrateError('Error on parse version "%(ver)s" of file "%(f)s": %(e)s' % {
                    'ver': ver_text,
                    'f': f,
                    'e': e,
                }) from e

            if _ver in res:
                raise MigrateError('Version duplicates on parse file: %s' % f)

            res[_ver] = p_join(self.patch_dir, f)

        for i in sorted(res):
            if i > ver:
                yield i, res[i]

    @staticmethod
    def get_commands(file):
        """Yield the SQL commands stored in ``file`` one by one.

        Every line whose first non-blank characters are ``--`` is treated as
        a separator between commands and is itself discarded. The lines
        between two separators form one command. Chunks that contain only
        whitespace are skipped, and surrounding whitespace of each command
        is stripped.

        :param file: path of the SQL file to read.
        """
        buf = []
        with open(file) as src:
            for line in src:
                if line.lstrip().startswith('--'):
                    # Lines keep their own newline, so join without adding
                    # another one (that would insert blank lines into
                    # multi-line statements and string literals).
                    cmd = ''.join(buf).strip()
                    buf = []
                    if cmd:
                        yield cmd

                else:
                    buf.append(line)

        cmd = ''.join(buf).strip()
        if cmd:
            yield cmd

    def init_db(self, db):
        """Execute every command of the schema file on ``db``.

        A commit is issued after each command and once more at the end.

        :param db: connection object providing ``cursor()`` and ``commit()``.
        """
        cursor = db.cursor()
        try:
            for c in self.get_commands(self.schema):
                cursor.execute(c)
                db.commit()

            db.commit()

        finally:
            cursor.close()

    def check(self, db):
        """Apply all patches newer than the version stored in the database.

        The current version is read from the first row returned for the
        control table; when there is no row it is taken as ``-1``. Each
        pending patch is executed command by command with a commit after
        every command. Once a patch file has been applied, its version is
        written to the control table and committed, so a failure in a later
        patch does not lose the progress already made.

        NOTE(review): the version is stored with ``UPDATE``, which changes
        nothing if the control table has no row -- presumably ``schema.sql``
        inserts the initial row; confirm.

        :param db: connection object providing ``cursor()`` and ``commit()``.
        """
        cursor = db.cursor()
        try:
            cursor.execute("SELECT version FROM %s" % self.control_table)
            row = cursor.fetchone()
            ver = -1 if row is None else int(row[0])

            for up_ver, patch_file in self.get_patch_files(ver):
                for cmd in self.get_commands(patch_file):
                    cursor.execute(cmd)
                    db.commit()

                # The table name cannot be passed as a bound parameter and
                # up_ver is always an int, so plain formatting is used.
                cursor.execute(
                    "UPDATE %s SET version = %d" % (self.control_table, up_ver)
                )
                # Without this commit the recorded version is rolled back
                # when the connection is closed.
                db.commit()

        finally:
            cursor.close()


