# @package      hubzero-python
# @file         db.py
# @author       David Benham <dbenham@purdue.edu>
# @copyright    Copyright (c) 2012 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import exceptions
import itertools
import sys

# MySQLdb is a third-party package that may not be installed; fail fast with a clear message if it is missing
try:
	import MySQLdb
except ImportError:
	sys.stderr.write("Error: Cannot import MySQLdb library")
	sys.exit(1)


class MySQLConnection:
	"""
	Very lightweight MySQL library.

	Wraps a single MySQLdb connection. Each query helper opens a cursor,
	executes one (optionally parameterized) statement, collects the result
	and always closes the cursor again.
	"""

	_db = None          # live MySQLdb connection, or None when not connected
	_dbHost = ""
	_dbName = ""
	_dbPW = ""
	_dbUsername = ""


	def _db_connect(self):
		"""internal connect method: (re)open the connection from the stored settings"""

		self._db = MySQLdb.connect(host=self._dbHost,
		                           user=self._dbUsername,
		                           db=self._dbName,
		                           passwd=self._dbPW)


	def __init__(self, dbHost, dbName, dbUsername, dbPW):
		"""
		Store the connection settings and try to connect immediately.

		A failed connection is reported on stderr but not raised; the next
		query retries the connection (see _cursor) and surfaces the real
		error at that point.
		"""

		self._dbHost = dbHost
		self._dbName = dbName
		self._dbPW = dbPW
		self._dbUsername = dbUsername

		try:
			self._db_connect()
		except Exception:
			# %-formatting (not +) so a None name cannot raise inside the handler
			sys.stderr.write("Error connecting to database: %s for user: %s\n" % (dbName, dbUsername))


	def __del__(self):
		""" Destructor: release the connection if it is still open """
		self.close()


	def close(self):
		"""Close the connection if it is open; safe to call more than once."""
		if self._db is not None:
			# clear the handle first so a failing close() cannot leave a dead
			# connection object behind for _cursor to reuse
			db, self._db = self._db, None
			db.close()


	def _cursor(self):
		"""
		Return a new cursor on the connection.

		Reconnects first if the connection was never opened (failed in
		__init__) or was closed, e.g. by _execute after an error.
		"""
		if self._db is None:
			self._db_connect()
		return self._db.cursor()


	def _execute(self, cursor, sql, parms):
		"""
		Execute sql with parms on cursor, with consistent error handling.

		On failure the error is reported on stderr, the connection is closed
		(so the next call reconnects via _cursor) and the original exception
		is re-raised.
		"""
		try:
			return cursor.execute(sql, parms)
		except Exception:
			sys.stderr.write("Error executing cursor on database: %s for user: %s\n" % (self._dbName, self._dbUsername))
			self.close()
			raise


	def query(self, query, parms=None):
		"""
		Return the rows of query as a list of dictWithAttributes objects.

		Each row maps column name -> value and also supports row.column
		access, avoiding row["column"] lookups. A statement that produces no
		result set (cursor.description is None) yields an empty list.
		"""
		cursor = self._cursor()
		try:
			self._execute(cursor, query, parms)

			if cursor.description is None:
				return []

			columnNames = [column[0] for column in cursor.description]
			return [dictWithAttributes(zip(columnNames, row)) for row in cursor]
		finally:
			cursor.close()


	def query_lastrowid(self, sql, parms=None):
		"""
		Execute a single insert and return the id of the inserted row
		(cursor.lastrowid).
		"""
		cursor = self._cursor()
		try:
			self._execute(cursor, sql, parms)
			return cursor.lastrowid
		finally:
			cursor.close()


	def query_rowcount(self, sql, parms=None):
		"""
		Execute sql and return cursor.rowcount.
		"""
		cursor = self._cursor()
		try:
			self._execute(cursor, sql, parms)
			return cursor.rowcount
		finally:
			cursor.close()


	def query_selectscalar(self, sql, parms=None):
		"""
		Return the first column of the single row produced by sql.

		Returns None unless the query yields exactly one row.
		"""
		cursor = self._cursor()
		try:
			self._execute(cursor, sql, parms)
			rows = cursor.fetchall()
		finally:
			cursor.close()

		if len(rows) != 1:
			return None
		return rows[0][0]
		
		
class dictWithAttributes(dict):
	"""
	dict subclass that also exposes its keys as attributes, so query rows can
	be read as row.column as well as row["column"].

	A missing key raises AttributeError (not KeyError) on attribute access,
	so hasattr() and getattr(obj, name, default) behave as Python expects.
	"""
	def __getattr__(self, name):
		# __getattr__ only runs after normal lookup fails, so real dict
		# methods (keys, items, ...) still take precedence over same-named keys
		try:
			return self[name]
		except KeyError:
			raise AttributeError(name)
