From fd2084b8601a670cb4ca75d78da46a6ec433cf4b Mon Sep 17 00:00:00 2001 From: ringzero Date: Tue, 24 Jan 2017 10:41:18 +0800 Subject: [PATCH] PreparedStatement anti sql injection --- libmysql.py | 56 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/libmysql.py b/libmysql.py index 51c54b1..23de820 100644 --- a/libmysql.py +++ b/libmysql.py @@ -41,18 +41,22 @@ def insert(self, table, data): params = self.join_field_value(data); sql = "INSERT IGNORE INTO {table} SET {params}".format(table=table, params=params) - result = cursor.execute(sql) - self.connection.commit() # not autocommit - + result = cursor.execute(sql, tuple(data.values())) + self.connection.commit() return result def delete(self, table, condition=None, limit=None): - """mysql delete() function""" + """ + mysql delete() function + sql.PreparedStatement method + """ with self.connection.cursor() as cursor: + prepared = [] # PreparedStatement if not condition: where = '1'; elif isinstance(condition, dict): where = self.join_field_value( condition, ' AND ' ) + prepared.extend(condition.values()) else: where = condition @@ -60,38 +64,46 @@ def delete(self, table, condition=None, limit=None): sql = "DELETE FROM {table} WHERE {where} {limits}".format( table=table, where=where, limits=limits) - result = cursor.execute(sql) + result = cursor.execute(sql, tuple(prepared)) self.connection.commit() # not autocommit return result def update(self, table, data, condition=None): - """mysql update() function""" + """ + mysql update() function + Use sql.PreparedStatement method + """ with self.connection.cursor() as cursor: + prepared = [] # PreparedStatement params = self.join_field_value(data) + prepared.extend(data.values()) if not condition: where = '1'; elif isinstance(condition, dict): where = self.join_field_value( condition, ' AND ' ) + prepared.extend(condition.values()) else: where = condition sql = "UPDATE {table} SET {params} WHERE {where}".format( table=table, params=params, where=where) - result = cursor.execute(sql) + result = cursor.execute(sql, tuple(prepared)) self.connection.commit() # not autocommit - return result def count(self, table, condition=None): """count database record""" with self.connection.cursor() as cursor: + prepared = [] # PreparedStatement + # WHERE CONDITION if not condition: where = '1'; elif isinstance(condition, dict): where = self.join_field_value( condition, ' AND ' ) + prepared.extend(condition.values()) else: where = condition @@ -100,15 +112,16 @@ def count(self, table, condition=None): table=table, where=where) # EXECUTE SELECT COUNT sql - cursor.execute(sql) + cursor.execute(sql, tuple(prepared)) # RETURN cnt RESULT return cursor.fetchone().get('cnt') - def fetch_rows(self, table, fields=None, condition=None, order=None, limit=None, fetchone=False): """mysql select() function""" with self.connection.cursor() as cursor: + prepared = [] # PreparedStatement + # SELECT FIELDS if not fields: fields = '*' @@ -123,6 +136,7 @@ def fetch_rows(self, table, fields=None, condition=None, order=None, limit=None, where = '1'; elif isinstance(condition, dict): where = self.join_field_value( condition, ' AND ' ) + prepared.extend(condition.values()) else: where = condition @@ -135,13 +149,9 @@ def fetch_rows(self, table, fields=None, condition=None, order=None, limit=None, # LIMIT NUMS limits = "LIMIT {limit}".format(limit=limit) if limit else "" sql = "SELECT {fields} FROM {table} WHERE {where} {orderby} {limits}".format( - fields=fields, - table=table, - where=where, - orderby=orderby, - limits=limits) + fields=fields, table=table, where=where, orderby=orderby, limits=limits) - cursor.execute(sql) + cursor.execute(sql, tuple(prepared)) if fetchone: return cursor.fetchone() @@ -160,19 +170,17 @@ def query(self, sql, fetchone=False): else: return cursor.fetchall() - def close(self): - if self.connection: - return self.connection.close() - def join_field_value(self, data, glue = ', '): sql = comma = '' - for key, value in data.iteritems(): - if isinstance(value, str): - value = pymysql.escape_string(value) - sql += "{0}`{1}` = '{2}'".format(comma, key, value) + for key in data.keys(): + sql += "{}`{}` = %s".format(comma, key) comma = glue return sql + def close(self): + if self.connection: + return self.connection.close() + def __del__(self): """close mysql database connection""" self.close()