import json

import pymysql


class MaDbHelper:
    """Store MQTT users and their publish/subscribe ACL patterns in MySQL.

    Rows live in the ``vmq_auth_acl`` table, which is created on construction
    if it is missing. Every user is written with an empty ``mountpoint`` and
    with ``client_id`` equal to the username. ACL pattern lists are stored as
    JSON arrays of ``{"pattern": <pattern>}`` objects.

    The helper owns one connection and one cursor. Call ``close()``, or use
    the instance as a context manager, when done with it.
    """

    def __init__(self, config):
        """Connect to MySQL and make sure the ACL table exists.

        Args:
            config: mapping with the keys ``MYSQL_HOST``, ``MYSQL_USER``,
                ``MYSQL_PASSWORD``, ``MYSQL_DATABASE`` and ``MYSQL_PORT``.
        """
        # ``password``/``database`` are PyMySQL's canonical keyword names;
        # ``passwd``/``db`` are only deprecated aliases for them.
        self.mysqlConn = pymysql.connect(
            host=config["MYSQL_HOST"],
            user=config["MYSQL_USER"],
            password=config["MYSQL_PASSWORD"],
            database=config["MYSQL_DATABASE"],
            port=config["MYSQL_PORT"],
        )
        self.mysqlCur = self.mysqlConn.cursor()
        try:
            self._initDb()
        except Exception:
            # Don't leak the freshly opened connection if table setup fails.
            self.close()
            raise

    def close(self):
        """Close the cursor and the connection; safe to call more than once."""
        if self.mysqlConn.open:
            self.mysqlCur.close()
            self.mysqlConn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Returning False lets any exception raised in the ``with`` body propagate.
        return False

    def _initDb(self):
        """
        Generates required tables
        """
        query = """
            CREATE TABLE IF NOT EXISTS vmq_auth_acl
            (
              mountpoint VARCHAR(10) NOT NULL,
              client_id VARCHAR(128) NOT NULL,
              username VARCHAR(128) NOT NULL,
              password VARCHAR(128),
              publish_acl TEXT,
              subscribe_acl TEXT,
              CONSTRAINT vmq_auth_acl_primary_key PRIMARY KEY (mountpoint, client_id, username)
            )
        """
        self._executeWrite(query)

    def _executeWrite(self, query, params=None):
        """Execute a modifying statement and commit it.

        On a database error the open transaction is rolled back before the
        error is re-raised, so the shared connection is not left mid-transaction.

        Args:
            query: SQL text using ``%s`` placeholders.
            params: sequence of placeholder values, or None for no parameters.
        """
        try:
            self.mysqlCur.execute(query, params)
            self.mysqlConn.commit()
        except pymysql.MySQLError:
            self.mysqlConn.rollback()
            raise

    def addUser(self, username, password, publishAclPatterns, subscribeAclPatterns):
        """Insert a new user with its ACL patterns.

        Args:
            username: user name; also stored as the row's ``client_id``.
            password: password value passed to MySQL's ``PASSWORD()`` function.
            publishAclPatterns: iterable of patterns stored in ``publish_acl``.
            subscribeAclPatterns: iterable of patterns stored in ``subscribe_acl``.

        Returns:
            True if the row was inserted, False if the username already exists.
        """
        if self.userExists(username):
            return False
        # NOTE(review): MySQL's PASSWORD() function was removed in MySQL 8.0.
        # Presumably it must match the hash method the consuming auth backend
        # is configured with -- confirm before changing it.
        query = "INSERT INTO `vmq_auth_acl` (`mountpoint`, `client_id`, `username`, `password`, `publish_acl`, `subscribe_acl`) VALUES (%s, %s, %s, PASSWORD(%s), %s, %s);"
        self._executeWrite(
            query,
            (
                "",
                username,
                username,
                password,
                self._convertAclPatternList(publishAclPatterns),
                self._convertAclPatternList(subscribeAclPatterns),
            ),
        )
        return True

    def userExists(self, username):
        """Return True if at least one row has the given ``username``."""
        query = "SELECT username FROM `vmq_auth_acl` WHERE username=%s LIMIT 1"
        # Parameters must be a sequence: ``(username,)``, not ``(username)``.
        self.mysqlCur.execute(query, (username,))
        return self.mysqlCur.fetchone() is not None

    def updateUser(self, username, password, publishAclPatterns, subscribeAclPatterns):
        """Replace the password and ACL patterns of an existing user.

        Returns:
            True if the user existed and was updated, False otherwise.
        """
        if not self.userExists(username):
            return False
        self._updateUser(username, password, publishAclPatterns, subscribeAclPatterns)
        return True

    def getAllUsers(self):
        """Return the ``username`` column of every row as a list."""
        query = "SELECT username FROM `vmq_auth_acl`;"
        self.mysqlCur.execute(query)
        return [row[0] for row in self.mysqlCur.fetchall()]

    def _updateUser(self, username, password, publishAclPatterns, subscribeAclPatterns):
        """Unconditionally update every row matching ``username`` and commit."""
        query = """
            UPDATE `vmq_auth_acl` SET
              `password`=PASSWORD(%s),
              `publish_acl`=%s,
              `subscribe_acl`=%s
            WHERE
              `username`=%s;
        """
        self._executeWrite(
            query,
            (
                password,
                self._convertAclPatternList(publishAclPatterns),
                self._convertAclPatternList(subscribeAclPatterns),
                username,
            ),
        )

    def _convertAclPatternList(self, patternList):
        """Serialize patterns as a JSON array of ``{"pattern": ...}`` objects."""
        return json.dumps([{"pattern": pattern} for pattern in patternList])