/* 
   Copyright (C) 1999 E. H. Haley

   trwth.c heads towards cluster-weighting the layers.

   maybe for diagnostic have it draw each whole layer in init as it's done

 */

#include <stdio.h>
#include <stdlib.h>
#include <jpeglib.h>
#include "lib/loadj.h"
#include <GL/glut.h>
#include <X11/Xlib.h>
#include "lib/tile.h"
#include <math.h>

// Number of clusters (Gaussian components) used by cwminit() and idle().
#define NCLU 10
// sqrt(2*pi); divides the Gaussian densities computed in idle().
#define RTP 2.5066283 //root two pi

// Image width and height in pixels; init() sets both from the match-file
// header's `side`, and main() uses them for the window size.
int dw,dh;
// Number of stacked images held in `layers`; init() sets it to log2(side)-3.
int nlayers;
// mix:    3 bytes per pixel, dw*dh pixels; written by idle() through
//         yiq256rgb() and drawn by display() as GL_RGB.
// layers: nlayers images of dw*dh pixels, 3 bytes each; channel k of pixel
//         (x,y) in layer v lives at 3*(dw*(dh*v+y)+x)+k.
JSAMPLE *mix, *layers;
// Target image as produced by rgb256yiq(): 3 floats per pixel, pixel (x,y)
// at 3*(dw*y+x).  Only channel 0 is compared against the layers in idle().
float *oryiq;
// mux[m]:  mean position (x,y) of cluster m.
// sigy[m]: spread of the channel-0 residual for cluster m.
float mux[NCLU][2], sigy[NCLU];
// pc[m]:   weight of cluster m (mean of pc_yx over all pixels).
// norm[m]: sum of pc_yx over all pixels for cluster m.
float pc[NCLU], norm[NCLU];
// Per-cluster 2x2 position covariance, its inverse, and
// rdcim[m] = sqrt(det(invcov[m])).
float cov[NCLU][2][2], invcov[NCLU][2][2], rdcim[NCLU];
// Per-cluster, per-pixel tables of dw*dh*NCLU floats, allocated in
// cwminit(); the entry for cluster m at pixel (i,j) is at NX*(NY*m+j)+i.
//   px_c : Gaussian density of the pixel position under cluster m
//   py_xc: Gaussian density of the channel-0 residual under cluster m
//   pc_yx: px_c*py_xc*pc, normalised over the clusters at each pixel
float *px_c, *py_xc, *pc_yx;

/*
 * init -- load the inputs and paint the tile layers.
 *
 * aardv[1] names the match file: one tfile header followed by
 * tatch.nnodes tile records.  aardv[2] names the target image, read with
 * rjf(); it must be side x side pixels (side comes from the header) with
 * three channels.
 *
 * Sets dw, dh and nlayers, allocates and fills oryiq (the target run
 * through rgb256yiq) and layers.  Tiles are consumed in file order: one
 * tile for the first level, then 2x2, 4x4, ...  The first level lands in
 * layer slot nlayers-1 and each later level one slot lower.  Returns once
 * the tile list is exhausted; exits the process on unusable input.
 */
void init(int aardc, char **aardv)
{
  tile *dcont; //display_contender
  tfile tatch;
  int p,w,h,c, ow,oh,oc, x,y, l,i,j,k, level, side, logw, base;
  FILE *matchfile;
  JSAMPLE *arr, *temp, *original;

  // NOTE(review): main() calls init() before glutCreateWindow(), so these
  // two calls run without a GL context -- confirm they have any effect.
  glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
  glReadBuffer(GL_FRONT);

  if (aardc<3)
    { fprintf(stderr, "usage: trwth <matchfile> <image>\n"); exit(0); }
  // The match file is raw structs, so open it in binary mode.
  if ((matchfile=fopen(*++aardv,"rb"))==NULL)
    { fprintf(stderr, "can't open %s\n",*aardv); exit(1); }

  if (fread(&tatch,sizeof(tfile),1,matchfile)!=1)
    { fprintf(stderr, "can't read header from %s\n",*aardv); exit(1); }
  printf("%d\n",tatch.nnodes);
  if (tatch.nnodes<1 || tatch.side<1)
    { fprintf(stderr, "bad header in %s\n",*aardv); exit(1); }
  side=dw=dh=tatch.side;

  dcont=malloc((size_t)tatch.nnodes*sizeof *dcont);
  if (dcont==NULL)
    { fprintf(stderr, "out of memory\n"); exit(1); }
  if (fread(dcont,sizeof *dcont,(size_t)tatch.nnodes,matchfile)
      !=(size_t)tatch.nnodes)
    { fprintf(stderr, "can't read tiles from %s\n",*aardv); exit(1); }
  fclose(matchfile);

  original=rjf(*++aardv,&ow,&oh,&oc);
  if (original==NULL)
    { fprintf(stderr, "can't read %s\n",*aardv); exit(1); }
  if((ow!=dw)||(oh!=dh))
    { fprintf(stderr,"Too confusing, I quit.\n"); exit(0); }
  // The conversion below strides 3 samples per pixel.
  if (oc!=3)
    { fprintf(stderr, "%s: need a 3-channel image\n",*aardv); exit(1); }
  oryiq=malloc((size_t)3*dw*dh*sizeof *oryiq);
  if (oryiq==NULL)
    { fprintf(stderr, "out of memory\n"); exit(1); }
  for(i=0; i<dw*dh; i++)
    rgb256yiq(original+3*i,oryiq+3*i);
  free(original);

  logw = 0;
  while((side>>=1)>0) logw++;
  nlayers=logw-3;  //uses: going down to 16x16.
  if (nlayers<1)
    { fprintf(stderr, "side %d is too small\n",dw); exit(1); }

  //------------------

  // Zero-filled so that slots whose tile is missing or unreadable are
  // black rather than indeterminate when idle() reads them.
  layers=calloc((size_t)3*nlayers*dw*dh, sizeof *layers);
  if (layers==NULL)
    { fprintf(stderr, "out of memory\n"); exit(1); }

  p=-1;
  level= nlayers;
  for(l=1; l>0; l<<=1) //until the tile list runs out
    {
      level--;
      for(i=0; i<l; i++)
        for(j=0; j<l; j++)
          {
            if (++p>=tatch.nnodes)
              {
                free(dcont);
                return; //i.e. go to cwminit()
              }
            if (level<0)
              {
                // More tiles than the layers can hold; writing them would
                // index before the start of the buffer.
                fprintf(stderr, "ignoring tiles from p=%d on: no layer left\n",p);
                free(dcont);
                return;
              }

            w=h=c=0; // so the report below is defined if rjf fails
            if((arr=rjf(dcont[p].fname,&w,&h,&c))!=NULL)
              {
                temp=rez(arr,w,h,dcont[p].exp,c); //for zoom
                if (temp!=NULL)
                  {
                    for(x=0; x<dw/l; x++)
                      for(y=0; y<dh/l; y++)
                        //those should be the dimensions of the thing that comes out of rez
                        //maybe rez should alter the w&h passed in...
                        // NOTE(review): temp is indexed with the pre-rez
                        // width w -- confirm that matches rez's output.
                        {
                          base=3*(dw*(dh*level+(y+j*dh/l)) +(x+i*dw/l));
                          for(k=0; k<c; k++)
                            layers[base+k]=temp[c*(w*y+x)+k];
                          if(c==1)
                            {
                              // Single-channel tile: clear the other two
                              // channels of the same pixel.
                              layers[base+1]=0;
                              layers[base+2]=0;
                            }
                        }
                    free(temp);
                  }
              }
            printf("p=%d; f=%s; orig w=h=%d; want=%d; exp=%d; score=%d\n",
                   p,
                   dcont[p].fname,
                   w,
                   dw/l,
                   dcont[p].exp,
                   dcont[p].score);

            fflush(stdout);
            free(arr);
          }
    }
}

/*
 * cwminit -- allocate the per-cluster tables and the output image, and
 * give every cluster its starting parameters.
 *
 * Must run after init(), which sets dw and dh.  Exits the process if an
 * allocation fails.
 */
void cwminit(void)
{
  int d,d2, m, NX=dw, NY=dh;
  size_t ncell=(size_t)NX*NY*NCLU;

  //set up cwm init stuff

  px_c=malloc(ncell*sizeof *px_c);
  py_xc=malloc(ncell*sizeof *py_xc);
  pc_yx=malloc(ncell*sizeof *pc_yx);
  mix=malloc((size_t)3*NX*NY*sizeof *mix);
  if (px_c==NULL || py_xc==NULL || pc_yx==NULL || mix==NULL)
    { fprintf(stderr, "out of memory\n"); exit(1); }

  for(m=0; m<NCLU; m++)
    {
      // Means start uniformly in [-10,10).  random() returns values in
      // [0, 2^31-1] whatever RAND_MAX (the limit of rand()) happens to be,
      // so normalise by 2^31.
      for(d=0; d<2; d++)
        mux[m][d]=20.0*((float)random()/2147483648.0)-10.0;
      sigy[m]=1.0; //?

      for(d=0; d<2; d++)
        for(d2=0; d2<2; d2++)
          {
            cov[m][d][d2]=100.0*(1-(d+d2)%2); //e.g. {{100, 0}, {0, 100}} //100 bcs ~ sigx^2
            // NOTE(review): this is {{1/1000, 0}, {0, 1/1000}}, not the
            // inverse of cov above.  idle() recomputes cov before reading
            // it, so invcov is the value that drives the first pass.
            invcov[m][d][d2]=0.001*(1-(d+d2)%2);
          }
      rdcim[m] = sqrt(invcov[m][0][0]*invcov[m][1][1] - invcov[m][1][0]*invcov[m][0][1]);
      pc[m]=1.0/(float)NCLU; //?
    }
}

/*
 * idle -- GLUT idle callback: one iteration of the cluster fit, then
 * rebuild the displayed image.
 *
 * Steps, in order:
 *   1. px_c  : Gaussian density of each pixel position under each cluster
 *   2. py_xc : Gaussian density of the channel-0 residual between the
 *              target and layer (m % nlayers)
 *   3. pc_yx : px_c*py_xc*pc, normalised over clusters at each pixel
 *   4. pc, norm, mux, cov/invcov/rdcim, sigy re-estimated from pc_yx
 *   5. mix   : per-pixel blend of the layers weighted by px_c*pc
 *
 * Reads and writes the file-scope cluster state; requests a redraw.
 */
void idle(void)
{
  //do a cwm iteration

  int i,j,k, m, d,d2, NX=dw, NY=dh;
  int cell, lay;
  // These three hold fractional values and must be floating point.
  float term, thing, small=0.01;
  float resid;
  float num[3], denom;

  for(m=0; m<NCLU; m++)
    {
      //      rdcim[m] = sqrt(invcov[m][0][0]*invcov[m][1][1] - invcov[m][1][0]*invcov[m][0][1]);
      for(i=0; i<NX; i++)
        for(j=0; j<NY; j++)
          {
            cell=NX*(NY*m+j)+i;
            px_c[cell]= rdcim[m]/RTP/RTP;
            // Product of the four exp() factors == exp(-quadratic form/2).
            for(d=0; d<2; d++)
              for(d2=0; d2<2; d2++)
                px_c[cell]*=
                  exp(-((d==0?i:j)-mux[m][d])
                      *invcov[m][d][d2]
                      *((d2==0?i:j)-mux[m][d2])/2.0);
          }
    }

  //update py_xc
  for(m=0; m<NCLU; m++)
    {
      lay=m%nlayers; // cluster m models layer m mod nlayers
      for(i=0; i<NX; i++)
        for(j=0; j<NY; j++)
          {
            resid=oryiq[3*(dw*j+i)]-layers[3*(dw*(dh*lay+j)+i)+0];
            py_xc[NX*(NY*m+j)+i]=
              exp(-resid*resid/(2.0*sigy[m]*sigy[m]))
              /(RTP*sigy[m]);
          }
    }

  //update pc_yx
  for(i=0; i<NX; i++)
    for(j=0; j<NY; j++)
      {
        denom=0.0;
        for(m=0; m<NCLU; m++)
          {
            cell=NX*(NY*m+j)+i;
            pc_yx[cell]=py_xc[cell]*px_c[cell]*pc[m];
            denom+=pc_yx[cell];
          }
        // All terms are zero when denom is; leave them rather than make NaN.
        if (denom>0.0)
          for(m=0; m<NCLU; m++)
            pc_yx[NX*(NY*m+j)+i]/=denom;
      }

  //update pc and norm[m] = sum over pixels of pc_yx (pc is that sum / pixel count)
  // NOTE(review): norm[m] is zero if no pixel supports cluster m; the
  // divisions by norm[m] below are not guarded against that.
  for(m=0; m<NCLU; m++)
    {
      norm[m]=0.0;
      for(i=0; i<NX; i++)
        for(j=0; j<NY; j++)
          norm[m]+=pc_yx[NX*(NY*m+j)+i];
      pc[m]=norm[m]/(NX*NY);
    }

  //update mux
  for(d=0; d<2; d++)
    for(m=0; m<NCLU; m++)
      {
        mux[m][d]=0.0;
        for(i=0; i<NX; i++)
          for(j=0; j<NY; j++)
            mux[m][d]+=(d==0?i:j)*pc_yx[NX*(NY*m+j)+i];
        mux[m][d]/=norm[m];
      }

  //update cov
  for(m=0; m<NCLU; m++)
    {
      for(d=0; d<2; d++)
        for(d2=0; d2<2; d2++)
          {
            cov[m][d][d2]=0.0;
            for(i=0; i<NX; i++)
              for(j=0; j<NY; j++)
                cov[m][d][d2]+=((d==0?i:j)-mux[m][d])*((d2==0?i:j)-mux[m][d2])*pc_yx[NX*(NY*m+j)+i];
            cov[m][d][d2]/=norm[m];
          }
      // 2x2 inverse: swap the diagonal, negate the off-diagonal, divide
      // by the determinant.
      term=1.0/(cov[m][0][0]*cov[m][1][1] - cov[m][1][0]*cov[m][0][1]);
      invcov[m][0][0]= term*cov[m][1][1];
      invcov[m][1][1]= term*cov[m][0][0];
      invcov[m][1][0]= -term*cov[m][0][1];
      invcov[m][0][1]= -term*cov[m][1][0];
      rdcim[m] = sqrt(invcov[m][0][0]*invcov[m][1][1] - invcov[m][1][0]*invcov[m][0][1]);
    }

  //update sigy -- to cluster-expec of <(y-beta_m)^2>_m
  for(m=0; m<NCLU; m++)
    {
      lay=m%nlayers;
      sigy[m]=0.0;
      for(i=0; i<NX; i++)
        for(j=0; j<NY; j++)
          {
            resid=oryiq[3*(dw*j+i)+0]-layers[3*(dw*(dh*lay+j)+i)+0];
            sigy[m]+=resid*resid*pc_yx[NX*(NY*m+j)+i];
          }
      sigy[m]/=norm[m];
      sigy[m]=sqrt(sigy[m]);
      sigy[m]+=small; // floor, so py_xc never divides by zero
    }

  //now still gotta calculate where that leaves us -- cluster-weight that model.
  for(i=0; i<dw; i++)
    for(j=0; j<dh; j++)
      {
        // Fresh accumulators for every pixel.
        num[0]=num[1]=num[2]=0.0;
        denom=0.0;
        for(m=0; m<NCLU; m++)
          {
            thing=0.0; // quadratic form for this cluster only
            for(d=0; d<2; d++)
              for(d2=0; d2<2; d2++)
                thing += ( (float) ((d==0)?i:j) -mux[m][d])
                  *invcov[m][d][d2]
                  *( (float) ((d2==0)?i:j) -mux[m][d2]);
            term = exp(-thing/2.0) * rdcim[m] * pc[m] /RTP/RTP;
            for(k=0; k<3; k++)
              num[k]+=term * layers[3*(dw*(dh*(m%nlayers)+j)+i)+k];
            denom+=term;
          }
        // With no weight at all the pixel stays at zero instead of 0/0.
        if (denom>0.0)
          for(k=0; k<3; k++)
            num[k]/=denom;
        yiq256rgb(num, mix+3*(dw*j+i));
      }
  glutPostRedisplay();
}

/*
 * display -- GLUT display callback: blit the current blended image
 * (mix, dw x dh, 3 bytes per pixel) and present it.
 */
void display(void)
{
  const int width = dw;
  const int height = dh;

  glDrawPixels(width, height, GL_RGB, GL_UNSIGNED_BYTE, mix);

  glFlush();
  glutSwapBuffers();
}

  
void mouse(int button, int state, int x, int y)
{
  free(oryiq);
  free(layers);
  free(mix);    
  exit(0);
}


int main(int aardc, char **aardv)
{
  glutInit(&aardc, aardv);
  init(aardc, aardv);
  cwminit();

  glutInitDisplayMode(GLUT_DOUBLE | GLUT_RGBA);
  glutInitWindowSize(dw,dh);
  glutCreateWindow(aardv[0]);
  glutDisplayFunc(display);
  glutMouseFunc(mouse);
  glutIdleFunc(idle);

  glViewport(0,0, (GLfloat)dw,(GLfloat)dh);
  glMatrixMode(GL_PROJECTION);
  glLoadIdentity();
  gluOrtho2D(0.0, (GLfloat)dw, 0.0, (GLfloat)dh);
  glMatrixMode(GL_MODELVIEW);
  glLoadIdentity();

  glutMainLoop();
  exit(0);
}

