/* $PostgresPy: if/src/query.c,v 1.3 2004/07/27 16:18:14 flaw Exp $
 *
 * † Instrument:
 *     Copyright 2004, rhid development. All Rights Reserved.
 *     
 *     Usage of the works is permitted provided that this
 *     instrument is retained with the works, so that any entity
 *     that uses the works is notified of this instrument.
 *     
 *     DISCLAIMER: THE WORKS ARE WITHOUT WARRANTY.
 *     
 *     [2004, Fair License; rhid.com/fair]
 *     
 *//*
 * Postgres' Cursor Interface [For Python]
 *
 * This implements a DB-API 2.0 compatible cursor type.
 */
#include <pputils.h>

#include <postgres.h>
#include <miscadmin.h>
#include <access/heapam.h>
#include <catalog/pg_type.h>
#include <executor/executor.h>
#include <executor/execdesc.h>
#include <nodes/params.h>
#include <tcop/tcopprot.h>
#include <tcop/dest.h>
#include <tcop/pquery.h>
#include <utils/array.h>
#include <utils/palloc.h>
#include <utils/relcache.h>
#include <utils/portal.h>
#include <pg.h>
#include <PGExcept.h>

#include <Python.h>
#include <structmember.h>
#include <py.h>

#include "datum.h"
#include "tupd.h"
#include "tup.h"
#include "obj.h"
#include "rel.h"
#include "module.h"
#include "utils.h"

#include "query.h"
#include "portal.h"

static PyMemberDef PyPgQuery_Members[] = {
	{"lastoid", T_UINT, offsetof(struct PyPgQuery, lastoid), RO, "Last Oid"},
	{NULL}
};

/*
 * Initialize a query to ready it for execution
 */
/*
 * prepare - parse, analyze/rewrite and plan self->str.
 *
 * Parameter types come from self->argtypes / self->nargs. On success
 * self->queryl and self->planl receive parallel lists: one Query per
 * rewritten statement, and its Plan (NULL for utility statements).
 *
 * On failure a Python exception is set and self->queryl / self->planl are
 * left untouched; callers detect this with PyErr_Occurred().
 */
static void
prepare(PgQuery self)
{
	/*
	 * Variables assigned inside PgErr_TRAP are volatile-qualified on the
	 * pointer itself (not the pointee), so that their values are well
	 * defined if the trap unwinds via a non-local jump.
	 */
	List *volatile ptl = NIL;
	List *volatile ql = NIL;
	List *volatile pl = NIL;
	ListCell *pti;

	/* XXX: Set query memory context? then duplicate lists in plcontext? */

	PgErr_TRAP(
		ptl = pg_parse_query(PYASSTR(self->str));
	);
	if (PyErr_Occurred())
		return;

	foreach(pti, ptl)
	{
		Node *ptree = (Node *) lfirst(pti);
		List *volatile qlist = NIL;
		ListCell *qli;

		PgErr_TRAP(
			qlist = pg_analyze_and_rewrite(ptree, self->argtypes, self->nargs);
		);
		/* Stop at the first failing statement rather than silently skip it. */
		if (PyErr_Occurred())
			return;

		foreach(qli, qlist)
		{
			Query *qtree = (Query *) lfirst(qli);
			Plan *volatile planTree = NULL;

			if (qtree->commandType == CMD_UTILITY)
			{
				Node *us = qtree->utilityStmt;

				/* Cursor and transaction control are owned by this module. */
				if (IsA(us, DeclareCursorStmt) ||
					IsA(us, ClosePortalStmt) ||
					IsA(us, FetchStmt)
					/* XXX: if !nested xacts */
					|| IsA(us, TransactionStmt))
				{
					/* Report the node tag, not the node's address. */
					PyErr_Format(PyExc_PgErr,
						"Utility statement (%d) is prohibitted", (int) nodeTag(us));
					return;
				}

				/* Utility statements are not planned; planTree stays NULL. */
			}
			else
			{
				PgErr_TRAP(
					planTree = pg_plan_query(qtree);
				);
				/* Never append an unset plan after a planner error. */
				if (PyErr_Occurred())
					return;
			}

			pl = lappend(pl, planTree);
			ql = lappend(ql, qtree);
		}
	}

	self->queryl = ql;
	self->planl = pl;
}

/*
 * execute
 * 	A simple function for non-SELECT statements.
 * 	All SELECT statements are portalized, except SELECT INTO.
 */
/*
 * execute
 * 	A simple function for non-SELECT statements.
 * 	All SELECT statements are portalized, except SELECT INTO.
 *
 * Runs (query, plan) to completion with parameters pli, discarding any
 * tuples (None_Receiver). On success it records es_lastoid in
 * self->lastoid and returns es_processed.
 *
 * On a backend error a Python exception is set (via PgErr_TRAP) and 0 is
 * returned; callers must check PyErr_Occurred().
 */
static uint32
execute(PgQuery self, Query *query, Plan *plan, ParamListInfo pli)
{
	QueryDesc *qd;
	uint32 processed = 0;	/* returned as-is on the error paths */

	qd = CreateQueryDesc(query, plan, None_Receiver, pli, false);

	PgErr_TRAP( ExecutorStart(qd, true, false); );
	if (PyErr_Occurred())
		goto end;

	/* Trapped so that a backend error becomes a Python exception. */
	PgErr_TRAP( ExecutorRun(qd, ForwardScanDirection, 0); );
	if (PyErr_Occurred())
		goto end;

	/* Read results before ExecutorEnd releases the executor state. */
	processed = qd->estate->es_processed;
	self->lastoid = qd->estate->es_lastoid;

	PgErr_TRAP( ExecutorEnd(qd); );
end:
	FreeQueryDesc(qd);
	return(processed);
}

/*
 * query_dealloc - tp_dealloc: release everything owned by a Query object.
 *
 * Tolerates a partially initialized object (e.g. __init__ failed before the
 * query string was stored), so every field may be NULL.
 */
static void
query_dealloc(PgQuery self)
{
	self->lastoid = InvalidOid;
	self->nargs = 0;

	if (self->argtypes)
	{
		pfree(self->argtypes);
		self->argtypes = NULL;
	}

	/* Only the list cells are freed here, not the Query/Plan nodes. */
	freeList(self->queryl);
	self->queryl = NULL;

	freeList(self->planl);
	self->planl = NULL;

	/* str is NULL when initialization did not complete. */
	Py_XDECREF(self->str);
	self->str = NULL;

	PyObject_Del((PyObj)self);
}

/*
 * query_repr - tp_repr: "<Postgres.Query:QUERYTEXT@ADDRESS>".
 */
static PyObj
query_repr(PgQuery self)
{
	char *text = PYASSTR(self->str);
	PyObj repr = PYSTRF("<Postgres.Query:%s@%p>", text, self);

	return(repr);
}

/*
 * query_call - tp_call: execute the prepared statements.
 *
 * Positional arguments are the query parameters; keyword arguments are
 * rejected. The number of positional arguments must equal self->nargs.
 *
 * The prepared statements are run in order:
 *   - a plain SELECT (no INTO) is portalized with PgPortal_New;
 *   - any other statement is run by execute(). Non-utility statements
 *     yield the number of processed tuples; utility statements yield nothing.
 *
 * Returns None when no statement produced a result, the single result when
 * exactly one did, or a list of all results. Returns NULL with an
 * exception set on failure.
 */
static PyObj
query_call(PgQuery self, PyObj args, PyObj kw)
{
	long q, i;
	PyObj ob = NULL, rob = NULL;

	Datum *dargs = NULL;
	bool *nulls = NULL;
	ParamListInfo pli = NULL;

	ListCell *queryc, *planc;

	if (PyCFunctionErr_NoKeywordsAllowed(kw))
		return(NULL);

	q = PySequence_Length(args);
	if (q < 0)
		return(NULL);

	if ((long) self->nargs != q)
	{
		/*
		 * %ld, not %lu: older PyErr_Format implementations do not know %lu
		 * and would copy the directive verbatim into the message.
		 */
		PyErr_Format(PyExc_TypeError,
			"This query requires %ld arguments; given %ld.",
			(long) self->nargs, q
		);

		return(NULL);
	}

	if (q > 0)
	{
		dargs = palloc0(sizeof(Datum) * q);
		nulls = palloc0(sizeof(bool) * q);

		for (i = 0; i < q; ++i)
		{
			PyObj arg = PySequence_GetItem(args, i);

			if (arg == NULL)
				goto fail;

			/*
			 * XXX: conversion of arg into dargs[i]/nulls[i] is not
			 * implemented; the parameter stays a zero, non-NULL Datum.
			 */
			DECREF(arg);
		}
	}
	pli = ParamListInfo_FromDatumsAndNulls(self->nargs, dargs, nulls);

	forboth(queryc, self->queryl, planc, self->planl)
	{
		Query *query = (Query *) lfirst(queryc);
		Plan *plan = (Plan *) lfirst(planc);

		if (query->commandType == CMD_SELECT && query->into == NULL)
		{
			/*
			 * NOTE(review): pli is pfree'd before returning; confirm that
			 * PgPortal_New copies the parameters rather than keeping pli.
			 */
			ob = (PyObj)PgPortal_New(self, query, plan, pli);
			if (ob == NULL)
				goto fail;
		}
		else
		{
			unsigned long processed;

			processed = execute(self, query, plan, pli);
			if (PyErr_Occurred())
				goto fail;

			if (query->commandType != CMD_UTILITY)
			{
				ob = PYLONG(processed);
				if (ob == NULL)
					goto fail;
			}
		}

		/* Utility statements don't return anything, ob is NULL */
		if (ob == NULL)
			continue;

		if (rob == NULL)
			rob = ob;
		else
		{
			/* Second result onward: promote rob to a list of results. */
			if (!PyList_CheckExact(rob))
			{
				PyObj l = PyList_New(0);

				if (l == NULL || PyList_Append(l, rob) < 0)
				{
					Py_XDECREF(l);
					DECREF(ob);
					goto fail;
				}
				DECREF(rob);
				rob = l;
			}

			if (PyList_Append(rob, ob) < 0)
			{
				DECREF(ob);
				goto fail;
			}
			DECREF(ob);
		}

		ob = NULL;
	}

	goto cleanup;

fail:
	Py_XDECREF(rob);
	rob = NULL;

cleanup:
	if (q > 0)
	{
		pfree(dargs);
		pfree(nulls);
		if (pli != NULL)
			pfree(pli);
	}

	if (rob)
		return(rob);
	if (PyErr_Occurred())
		return(NULL);
	RETURN_NONE;
}

/*
 * query_init - tp_init: Query(query_string, [argtype, ...]).
 *
 * args[0] is converted with str() and becomes the query text. Any further
 * arguments give the types of $1, $2, ... in order. Keywords are rejected.
 * The query is then prepared (parsed, rewritten, planned).
 *
 * Returns 0 on success, or -1 with an exception set.
 */
static int
query_init(PgQuery self, PyObj args, PyObj kw)
{
	long i, q;
	PyObj qsrc, qstr;

	self->lastoid = InvalidOid;

	if (PyCFunctionErr_NoKeywordsAllowed(kw))
		return(-1);

	q = PySequence_Length(args);
	if (q < 0)
		return(-1);
	if (q < 1)
	{
		PyErr_Format(PyExc_TypeError, "failed to setup query string");
		return(-1);
	}

	if (q > 1)
	{
		/* args[1..q-1] map to parameter types $1..$(q-1). */
		Oid *at = palloc(sizeof(Oid) * (q - 1));

		for (i = 1; i < q; ++i)
		{
			PyObj item = PySequence_GetItem(args, i);

			if (item == NULL)
			{
				pfree(at);
				return(-1);
			}

			/* assumes TypeOid_FromPyObject borrows item — TODO confirm */
			at[i - 1] = TypeOid_FromPyObject(item);
			DECREF(item);

			if (PyErr_Occurred())
			{
				pfree(at);
				return(-1);
			}
		}

		/* The query string itself is not a parameter. */
		self->nargs = q - 1;
		self->argtypes = at;
	}
	else
	{
		self->nargs = 0;
		self->argtypes = NULL;
	}

	qsrc = PySequence_GetItem(args, 0);
	if (qsrc == NULL)
		return(-1);

	qstr = PyObject_Str(qsrc);
	DECREF(qsrc);
	if (qstr == NULL)
	{
		PyErr_Format(PyExc_TypeError, "failed to setup query string");
		return(-1);
	}

	self->str = qstr; /* Give the reference to the query object */
	prepare(self);
	if (PyErr_Occurred())
		return(-1);

	return(0);
}

/*
 * query_str - tp_str: hand out a new reference to the stored query text.
 */
static PyObj
query_str(PgQuery self)
{
	PyObj text = self->str;

	INCREF(text);
	return(text);
}

static char PyPgQuery_Doc[] = "Postgres Query";

/*
 * Postgres.Query type object.
 *
 * Instances are created with PyType_GenericNew and initialized by
 * query_init. Calling an instance runs the prepared statements
 * (query_call); str() yields the query text.
 */
PyTypeObject PyPgQuery_Type = {
	PyObject_HEAD_INIT(NULL)
	0,										/* ob_size */
	"Postgres.Query",					/* tp_name */
	sizeof(struct PyPgQuery),		/* tp_basicsize */
	0,										/* tp_itemsize */
	(destructor)query_dealloc,		/* tp_dealloc */
	NULL,									/* tp_print */
	(getattrfunc)NULL,				/* tp_getattr */
	(setattrfunc)NULL,				/* tp_setattr */
	(cmpfunc)NULL,						/* tp_compare */
	(reprfunc)query_repr,			/* tp_repr */
	NULL,									/* tp_as_number */
	NULL,									/* tp_as_sequence */
	NULL,									/* tp_as_mapping */
	(hashfunc)NULL,					/* tp_hash */
	(ternaryfunc)query_call,		/* tp_call */
	(reprfunc)query_str,				/* tp_str */
	NULL,									/* tp_getattro */
	NULL,									/* tp_setattro */
	NULL,									/* tp_as_buffer */
	Py_TPFLAGS_DEFAULT,			   /* tp_flags */
	(char*)PyPgQuery_Doc,			/* tp_doc */
	(traverseproc)NULL,				/* tp_traverse */
	(inquiry)NULL,						/* tp_clear */
	(richcmpfunc)NULL,				/* tp_richcompare */
	(long)0,								/* tp_weaklistoffset */
	(getiterfunc)NULL,				/* tp_iter */
	(iternextfunc)NULL,				/* tp_iternext */
	NULL,									/* tp_methods */
	PyPgQuery_Members,				/* tp_members */
	NULL,									/* tp_getset */
	NULL,									/* tp_base */
	NULL,									/* tp_dict */
	NULL,									/* tp_descr_get */
	NULL,									/* tp_descr_set */
	0,										/* tp_dictoffset (integer, not a pointer) */
	(initproc)query_init,			/* tp_init */
	NULL,									/* tp_alloc */
	PyType_GenericNew,				/* tp_new */
};

/*
 * PyPgQuery_New - C-level constructor for Postgres.Query.
 *
 * args is handed to query_init unchanged, so args[0] is expected to be the
 * query text and any further items the parameter types.
 * NOTE(review): the `query` parameter is not used by this function.
 *
 * Returns a new reference, or NULL with an exception set.
 */
PyObj
PyPgQuery_New(PyObj query, PyObj args)
{
	PgQuery self = PgQuery_NEW();

	if (self != NULL)
	{
		self->ob_refcnt = 1;
		if (query_init(self, args, NULL) >= 0)
			return((PyObj)self);

		PyObject_Del(self);
	}

	return(NULL);
}
