/* Write a band plot in gnuplot form */
/* (c) MJR 2024, 2025 */

/* Valid command-line options:
 *
 * bands_gnu -- gnuplot output
 * bands_eps -- EPS output
 *
 * plus optional suffixes
 *
 * _mono  if spin polarised, all bands of the same spin have the same colour
 * _dash  (gnuplot only) if spin polarised, use dash patterns to
 *        differentiate spins
 * _ns    do not label special points
 * _dos   (EPS only) add a density of states panel beside the bands
 */

/* This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3
 * of the Licence, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see http://www.gnu.org/licenses/
 */

#include<stdio.h>
#include<math.h>
#include<stdlib.h>
#include<string.h>

#include "c2xsf.h"

/* Back ends for bands_plot(): write the prepared band data (k-point
 * distances and label/break flags) either as a gnuplot script or as EPS.
 */
static void bands_plot_gnu(FILE *outfile, struct unit_cell *c,
			   struct contents *m, struct kpts *k,
			   struct es *e, struct special_sc *spec,
			   struct symmetry *s,
			   double *k_dist, int *k_break, char *fmt);
static void bands_plot_eps(FILE *outfile, struct unit_cell *c,
			   struct contents *m, struct kpts *k,
			   struct es *e, struct special_sc *spec,
			   struct symmetry *s,
			   double *k_dist, int *k_break, char *fmt);

void bands_plot(FILE* outfile, struct unit_cell *c, struct contents *m,
		struct kpts *k, struct es *e, char *fmt){
  int i,ik,*k_break,okay,no_sp=0,exch=0;
  double *k_dist,dir[3],vtmp[3],x,tmp;
  struct special_sc *spec;
  struct symmetry s,s1;

  if ((k->path_nkpt==0)&&(k->n>0)&&(!e->path_eval)&&(e->eval)){
    exch=1;
    if (debug) fprintf(stderr,"promoting kpts to path kpts\n");
    k->path_nkpt=k->n;
    k->path_kpts=k->kpts;
    e->path_nbands=e->nbands;
    e->path_eval=e->eval;
  }
      
  if ((!e->path_eval)&&(k->path_nkpt>1)) {
    fprintf(stderr,"No eigenvalues to write!\n");
    return;
  }

  if (k->path_nkpt<1){
    fprintf(stderr,"No kpoints!\n");
    return;
  }

  if ((fmt)&&(strstr(fmt,"_ns"))) no_sp=1;
  spec=NULL;
  init_sym(&s);

  if (c->basis){
    if (!no_sp){
      spec=bspec_sc(c->basis);
      if (debug)
	fprintf(stderr,"Lattice detected: %s\n",spec->lat);

#ifdef SPGLIB
      init_sym(&s1);
      cspg_op(c,NULL,&s1,NULL,CSPG_SYM,tol);
      sym2ksym(&s1,&s);
      free_sym(&s1,1);
      if (debug) fprintf(stderr,
			 "Brillouin zone has %d symmetry operations\n",s.n);
#else
      s.n=2;
      s.ops=malloc(2*sizeof(struct sym_op));
      if (!s.ops) error_exit("malloc error in bands_plot");
      for(ik=0;ik<1;ik++)
	for(i=0;i<3;i++)
	  for(j=0;j<3;j++)
	    s.ops[ik].mat[i][j]=0;
      s.ops[0].tr=NULL;
      for(i=0;i<3;i++) s.ops[0].mat[i][i]=1;
      s.ops[1].tr=NULL;
      for(i=0;i<3;i++) s.ops[1].mat[i][i]=-1;
#endif
    }
  }

  if (k->path_nkpt==1){
    printf("(%f,%f,%f) -> %s\n",
	   k->path_kpts[0].frac[0],k->path_kpts[0].frac[1],
	   k->path_kpts[0].frac[2],
	   k_print(k->path_kpts[0].frac,spec,c,&s));
    return;
  }
    
  k_dist=malloc(k->path_nkpt*sizeof(double));
  k_break=malloc(k->path_nkpt*sizeof(int));
  if (!k_dist) error_exit("malloc error for k_dist in band_plot");
  if (!k_break) error_exit("malloc error for k_break in band_plot");

  for(i=0;i<3;i++)
    dir[i]=k->path_kpts[1].frac[i]-k->path_kpts[0].frac[i];
  tmp=1/sqrt(vmod2(dir));
  for(i=0;i<3;i++)
    dir[i]*=tmp;

  x=0;
  k_dist[0]=0;
  k_break[0]=1;
  for(i=1;i<k->path_nkpt-1;i++)
    k_break[i]=0;
  k_break[k->path_nkpt-1]=1;

  if (c->basis)
    addabs(k->path_kpts,k->path_nkpt,c->recip);
  else
    for(ik=0;ik<k->path_nkpt;ik++)
      for(i=0;i<3;i++)
	k->path_kpts[ik].abs[i]=k->path_kpts[ik].frac[i];
  
  for(ik=1;ik<k->path_nkpt;ik++){
    for(i=0;i<3;i++)
      vtmp[i]=k->path_kpts[ik].abs[i]-k->path_kpts[ik-1].abs[i];
    tmp=sqrt(vmod2(vtmp));
    x+=tmp;
    k_dist[ik]=x;
    for(i=0;i<3;i++)
      vtmp[i]=k->path_kpts[ik].frac[i]-k->path_kpts[ik-1].frac[i];
    tmp=1/sqrt(vmod2(vtmp));
    for(i=0;i<3;i++)
      vtmp[i]*=tmp;
    okay=1;
    for(i=0;i<3;i++)
      if (!aeq(dir[i],vtmp[i])) okay=0;
    if (!okay){
      k_break[ik-1]=1;
      for(i=0;i<3;i++)
	dir[i]=vtmp[i];
    }
  }

  x=1/x;
  for(i=1;i<k->path_nkpt;i++)
    k_dist[i]*=x;


  if (strstr(fmt,"_eps"))
    bands_plot_eps(outfile,c,m,k,e,spec,&s,k_dist,k_break,fmt);
  else
    bands_plot_gnu(outfile,c,m,k,e,spec,&s,k_dist,k_break,fmt);

  free(k_break);
  free(k_dist);
  free_sym(&s,1);

  if (exch==1){
    k->path_nkpt=0;
    k->path_kpts=NULL;
    e->path_nbands=0;
    e->path_eval=NULL;
  }
}

static void bands_plot_gnu(FILE *outfile, struct unit_cell *c,
			   struct contents *m, struct kpts *k,
			   struct es *e, struct special_sc *spec,
			   struct symmetry *s,
			   double *k_dist, int *k_break, char *fmt){
  int i,ik,ns,nb,ibreak,ss_off=0,ph=0,mono=0;
  char *spin_style=NULL,*cptr,*label;
  double scale;

  cptr=fmt;
  if (!strncmp(cptr,"--bands_",8)) cptr+=8;
  else if (!strncmp(cptr,"--phbands_",10)) {
    cptr+=10;
    ph=1;
  }
  else error_exit("unexpected fmt in bands_plot_gnu");

  if (!strncmp("mono",cptr,4)){
    mono=1;
    spin_style="ls";
    ss_off=1;
  }
  if (!strncmp("dash",cptr,4)){
    spin_style="dt";
    ss_off=2;
  }

  if (flags&AU)
    scale=1/H_eV;
  else
    scale=1;
  
  /* Print header */

  if (m->title) fprintf(outfile,"set title \"%s\"\n",m->title);
  fprintf(outfile,"set termoption enhanced\n");
  fprintf(outfile,"set style data linespoints\n");
  if (ph)
    fprintf(outfile,"set ylabel '{/Symbol w} (cm^{-1})'\n");
  else{
    if (flags&AU)
      fprintf(outfile,"set ylabel '{/Symbol e} (Ha)'\n");
    else
      fprintf(outfile,"set ylabel '{/Symbol e} (eV)'\n");
  }
  fprintf(outfile,"set xtics ( ");
  for(i=0;i<k->path_nkpt;i++){
    if (k_break[i]){
      if (i) fprintf(outfile,", ");
      label=k_print(k->path_kpts[i].frac,spec,c,s);
      if (!strcmp(label,"G"))
	fprintf(outfile,"'{/Symbol G}' %f",k_dist[i]);
      else if (!strcmp(label,"Sigma"))
	fprintf(outfile,"'{/Symbol S}' %f",k_dist[i]);
      else if (!strcmp(label,"Sigma1"))
	fprintf(outfile,"'{/Symbol S}1' %f",k_dist[i]);
      else
	fprintf(outfile,"'%s' %f",label,k_dist[i]);
      //      fprintf(outfile,"'%.2f,%.2f,%.2f' %f",k->kpts[i].frac[0]
      //	      ,k->kpts[i].frac[1],k->kpts[i].frac[2],k_dist[i]);
    }
  }
  fprintf(outfile,")\n");
  fprintf(outfile,"set xtics rotate by 315\n");
  fprintf(outfile,"set grid xtics lt -1\n");
  if ((e->nspins>1)&&(mono))
    fprintf(outfile,"set style line 1 lt rgb 'red'\n"
	    "set style line 2 lt rgb 'blue'\n"
	    "set style line 3 lt rgb 'green'\n"
	    "set style line 4 lt rgb 'magenta'\n");

  
  fprintf(outfile,"plot '-' notitle w l");
  if ((e->nspins>1)&&(spin_style))
    fprintf(outfile," %s %d",spin_style,ss_off);
  for(i=1;i<e->path_nbands*e->nspins;i++){
    fprintf(outfile,", \\\n '-' notitle w l");
    if ((e->nspins>1)&&(spin_style))
      fprintf(outfile," %s %d",spin_style,ss_off+i/e->path_nbands);
  }
  if (e->e_fermi)
    fprintf(outfile,", \\\n %f title 'E_{F}' lt 0 lc 'black'\n",
	    *e->e_fermi*scale);
  else
    fprintf(outfile,"\n");

  for(ns=0;ns<e->nspins;ns++){
    for(nb=0;nb<e->path_nbands;nb++){
      ibreak=-1;
      if (k->break_n) ibreak=0;
      for(ik=0;ik<k->path_nkpt;ik++){
	if ((ik>0)&&(ibreak>-1)){
	  if((aeq(k->path_kpts[ik-1].frac[0],k->breaks[6*ibreak]))&&
	     (aeq(k->path_kpts[ik-1].frac[1],k->breaks[6*ibreak+1]))&&
	     (aeq(k->path_kpts[ik-1].frac[2],k->breaks[6*ibreak+2]))&&
	     (aeq(k->path_kpts[ik].frac[0],k->breaks[6*ibreak+3]))&&
	     (aeq(k->path_kpts[ik].frac[1],k->breaks[6*ibreak+4]))&&
	     (aeq(k->path_kpts[ik].frac[2],k->breaks[6*ibreak+5]))){
	    ibreak++;
	    if (ibreak>=k->break_n) ibreak=-1;
	    fprintf(outfile,"\n");
	  }
	}
	fprintf(outfile,"%f %f\n",k_dist[ik],
		e->path_eval[ik*e->nspins*e->path_nbands+ns*e->path_nbands+nb]
		*scale);
      }
      fprintf(outfile,"end\n");
    }
  }

}

/* Decide whether k-point k1, after a coordinate permutation permitted by
 * the axis-equivalence pattern eq, matches k2 or -k2.
 *
 * eq gives a label to each of the three axes; a permutation is permitted
 * only when it maps every axis onto one carrying the same label (so "123"
 * permits the identity alone).  When tr is non-zero, components are
 * compared with dist(); otherwise they must be directly equal by aeq().
 *
 * Returns 1 on the first match found, 0 if none.
 */
static int kaeq(double k1[3], double k2[3], char eq[3], int tr){
  static const int perm[6][3]={{0,1,2},{1,2,0},{2,0,1},
			       {1,0,2},{2,1,0},{0,2,1}};
  const int *q;
  int n,d,sign;

  for(n=0;n<6;n++){
    q=perm[n];
    /* Skip permutations that mix axes with different labels */
    if ((eq[q[0]]!=eq[0])||(eq[q[1]]!=eq[1])||(eq[q[2]]!=eq[2]))
      continue;
    /* sign=1 tests k2 itself, sign=-1 tests -k2 */
    for(sign=1;sign>=-1;sign-=2){
      for(d=0;d<3;d++){
	if (tr){
	  if (!aeq(dist(k1[q[d]],sign*k2[d]),0)) break;
	}
	else if (!aeq(k1[q[d]],sign*k2[d])) break;
      }
      if (d==3) return 1;
    }
  }
  return 0;
}

/* Test whether k-point k1 equals the image of k-point k2 under any of
 * the symmetry operations in s.
 *
 * k2 is copied into an atom structure so that sym_vec() can apply each
 * operation.  The fractional coordinates of the image are then reduced
 * with fmod(...,1.0), which keeps the sign of its argument.  If tr is
 * non-zero the components are compared via dist() (presumably so that
 * reciprocal lattice translations are allowed -- confirm against dist());
 * otherwise they must agree directly via aeq().
 *
 * Without symmetry information or a cell, this falls back to kaeq() with
 * the identity permutation only ("123"), which compares k1 with k2 and -k2.
 *
 * Returns 1 on a match, 0 otherwise.
 */
static int kaeq_sym(double k1[3], double k2[3], struct unit_cell *c,
	     struct symmetry *s, int tr){
  struct atom a1,a2;
  int i,j,hit;

  if ((!s)||(s->n==0)||(!c)) return kaeq(k1,k2,"123",tr);

  /* NOTE(review): a1/a2 are set up with init_atoms() but never released;
   * check whether init_atoms() allocates anything. */
  init_atoms(&a1,1);
  init_atoms(&a2,1);
  
  for(i=0;i<s->n;i++){
    /* a1 holds k2; its absolute coordinates come from c->recip */
    for(j=0;j<3;j++)
      a1.frac[j]=k2[j];
    addabs(&a1,1,c->recip);
    /* a2 = image of a1 under operation i */
    sym_vec(&a1,&a2,s->ops+i,c->basis,0);
    for(j=0;j<3;j++) a2.frac[j]=fmod(a2.frac[j],1.0);
    //for(j=0;j<3;j++) if (a2.frac[j]<-tol) a2.frac[j]+=1.0;
    if (debug>2){
      fprintf(stderr,"kaeq_sym: (%f, %f, %f)  (%f, %f, %f)\n",
	      k2[0],k2[1],k2[2],a2.frac[0],a2.frac[1],a2.frac[2]);
      ident_sym(s->ops+i,c,NULL,stderr);
    }

    if (tr){
      hit=1;
      for(j=0;j<3;j++)
	if (!aeq(dist(k1[j],a2.frac[j]),0)) {hit=0; break;}
    }
    else{
      hit=0;
      if ((aeq(k1[0],a2.frac[0]))&&(aeq(k1[1],a2.frac[1]))&&
	  (aeq(k1[2],a2.frac[2]))) hit=1;
    }
    if (hit) return 1;
  }
  return 0;
}

/* Return a printable label for the fractional k-point kpt.
 *
 * A point with vmod2(kpt)<0.01 is reported as "G" (Gamma).
 * NOTE(review): if vmod2 is the squared length this is a radius of 0.1,
 * which is coarse.
 *
 * Otherwise, if sp is non-NULL, its special-point table (terminated by
 * an entry with an empty label) is searched.  The search first compares
 * directly and then modulo translations, each time via kaeq_sym() using
 * the symmetry operations in s.
 *
 * If nothing matches, the point is written as "a,b,c".  Each component
 * appears as 0, 1, an integer, a fraction n/d with d<10, or failing
 * those a 3-decimal number.
 *
 * The returned pointer may refer to a static buffer that is overwritten
 * by the next call; output longer than the buffer is truncated.
 */
char *k_print(double kpt[3], struct special_sc *sp,
	      struct unit_cell *c, struct symmetry *s){
  int i,j;
  double v,num;
  char comp[40];
  struct pt *p;
  static char buffer[100];

  if (vmod2(kpt)<0.01){
    return "G";
  }

  if (sp){
    /* First try without translation */
    p=sp->pts;
    while(*(p->l)){
      if (kaeq_sym(kpt,p->k,c,s,0)){
	return p->l;
      }
      p++;
    }
    /* Then with. Needed to distinguish FCC K and U, which are
     * translationally identical.
     */
    p=sp->pts;
    while(*(p->l)){
      if (kaeq_sym(kpt,p->k,c,s,1)){
	return p->l;
      }
      p++;
    }
  }

  buffer[0]=0;
  for(i=0;i<3;i++){
    v=kpt[i];
    if (fabs(v)<0.001)
      strcpy(comp,"0");
    else if (fabs(v-1)<0.001)   /* only values near 1, not everything >1 */
      strcpy(comp,"1");
    else{
      /* Smallest denominator j<10 for which j*v is (nearly) an integer;
       * j==1 means v itself is an integer */
      num=0;
      for(j=1;j<10;j++){
	num=floor(j*v+0.5);
	if (aeq(j*v,num)) break;
      }
      /* %.0f rather than (int) avoids overflow for large components */
      if (j==1)
	snprintf(comp,sizeof(comp),"%.0f",num);
      else if (j<10)
	snprintf(comp,sizeof(comp),"%.0f/%d",num,j);
      else
	snprintf(comp,sizeof(comp),"%.3f",v);
    }
    /* Append without ever writing past the end of buffer */
    strncat(buffer,comp,sizeof(buffer)-1-strlen(buffer));
    if (i!=2) strncat(buffer,",",sizeof(buffer)-1-strlen(buffer));
  }
  return buffer;
}

/* Helpers used by bands_plot_eps() but not defined in this file.  From
 * their use there, tic() appears to choose a tick spacing for a data
 * range, and yaxis_eps() to draw the y axis with its ticks -- confirm
 * against their definitions.
 */
double tic(double range);
void yaxis_eps(FILE *outfile, double ymin, double ymax,
	       double ytic, int pt_h, char *title, int flag);

/* Write the band structure as an Encapsulated PostScript figure.
 *
 * Eigenvalues are read from e->path_eval, laid out as [kpoint][spin][band].
 * k_dist gives each k-point's normalised horizontal position, and k_break
 * marks points that get a vertical grid line and a label.  Explicit
 * discontinuities in k->breaks start a new line segment instead of
 * joining the two points.  Each break is six values (the fractional
 * coordinates of two consecutive k-points), matched in path order.
 *
 * fmt: "--phbands_..." selects phonon axis labelling; "_mono" colours
 * lines by spin instead of by band index; "_dos" appends a density of
 * states panel, drawn by dos_plot() over the same energy range.
 */
static void bands_plot_eps(FILE *outfile, struct unit_cell *c,
			   struct contents *m, struct kpts *k,
			   struct es *e, struct special_sc *spec,
			   struct symmetry *s,
			   double *k_dist, int *k_break, char *fmt){
  double ymin,ymax,ytic,x,range[2],mid;
  char *label,*cptr,buffer[50];
  const char *tp;
  int i,pt_h,pt_w,ns,nb,ik,ibreak,ph=0,mono=0,dos=0;
  double scale;

  if (!strncmp(fmt,"--phbands_",10)) ph=1;
  if (strstr(fmt,"_mono")) mono=1;
  if (strstr(fmt,"_dos")) dos=1;

  if (flags&AU)
    scale=1/H_eV;
  else
    scale=1;

  /* Print header */

  pt_h=300;
  pt_w=500;

  /* Energy range over every spin, band and k-point */
  ymin=1e99;
  ymax=-1e99;
  for(i=0;i<e->nspins*e->path_nbands*k->path_nkpt;i++){
    ymin=min(ymin,e->path_eval[i]);
    ymax=max(ymax,e->path_eval[i]);
  }

  ymin=ymin*scale;
  ymax=ymax*scale;
  /* A flat band set (or no eigenvalues at all) would give a zero or
   * negative height range: tic(0) and the divisions by ymax-ymin below
   * would then be meaningless */
  if (!(ymax>ymin)){
    mid=(ymax<ymin)?0:ymin;
    ymin=mid-0.5;
    ymax=mid+0.5;
  }
  ytic=tic(ymax-ymin);
  ymin=ytic*floor(ymin/ytic);
  ymax=ytic*ceil(ymax/ytic);

  fprintf(outfile,"%%!PS-Adobe-2.0 EPSF-2.0\n");
  fprintf(outfile,"%%%%BoundingBox: 0 0 %d %d\n",pt_w+100+dos*400,pt_h+100);
  fprintf(outfile,"%%%%LanguageLevel: 2\n"
          "%%%%EndComments\n\n"
          "/ctrshow { dup stringwidth pop -0.5 mul 0 rmoveto show } def\n"
          "/rshow { dup stringwidth pop neg 0 rmoveto show } def\n");

  /* NOTE(review): entries 0 and 5 are both red [1 0 0]; confirm whether
   * entry 5 was meant to be a different colour. */
  fprintf(outfile,"%% Eight line colours\n"
	  "/LCs [[1 0 0] [0 1 0] [0 0 1] [1 0 1]\n"
	  "      [0 1 1] [1 0 0] [0 0 0] [1 0.3 0]] def\n");

  fprintf(outfile,"50 50 translate\n");
  if (m->title){
    /* Escape ( ) and \ so the title cannot terminate or corrupt the
     * PostScript string literal */
    fprintf(outfile,"/Helvetica-Bold 14 selectfont\n"
	    "%d %d moveto (",pt_w/2,pt_h+20);
    for(tp=m->title;*tp;tp++){
      if ((*tp=='(')||(*tp==')')||(*tp=='\\')) fputc('\\',outfile);
      fputc(*tp,outfile);
    }
    fprintf(outfile,") ctrshow\n");
  }

  fprintf(outfile,"/Helvetica 12 selectfont\n");
  fprintf(outfile,"0 0 moveto 0 %d lineto %d %d lineto %d 0 lineto "
          "closepath stroke\n",pt_h,pt_w,pt_h,pt_w);

  yaxis_eps(outfile,ymin,ymax,ytic,pt_h,NULL,0);
  fprintf(outfile,"-35 %.1f moveto gsave 90 rotate\n",0.5*pt_h);
  if (ph)
    fprintf(outfile,"(w (cm-1)) stringwidth pop -0.5 mul 0 rmoveto\n"
	    "gsave /Symbol 12 selectfont (w) show\n"
	    "currentpoint grestore moveto ( (cm-1)) show "
	    "grestore\n");
  else{
    if (flags&AU)
      fprintf(outfile,"(E (Ha)) stringwidth pop -0.5 mul 0 rmoveto\n"
	      "gsave /Symbol 12 selectfont (e) show\n"
	      "currentpoint grestore moveto ( (Ha)) show "
	      "grestore\n");
    else
      fprintf(outfile,"(E (eV)) stringwidth pop -0.5 mul 0 rmoveto\n"
	      "gsave /Symbol 12 selectfont (e) show\n"
	      "currentpoint grestore moveto ( (eV)) show "
	      "grestore\n");
  }

  /* Labels and vertical grid lines at flagged k-points */
  for(i=0;i<k->path_nkpt;i++){
    if (k_break[i]){
      x=pt_w*k_dist[i]/k_dist[k->path_nkpt-1];
      fprintf(outfile,"%.1f -17 moveto ",x);
      label=k_print(k->path_kpts[i].frac,spec,c,s);
      if (!strcmp(label,"G"))
	fprintf(outfile,"gsave /Symbol 12 selectfont (G) ctrshow grestore\n");
      else if (!strcmp(label,"Sigma"))
	fprintf(outfile,"gsave /Symbol 12 selectfont (S) ctrshow grestore\n");
      else if (!strcmp(label,"Sigma1"))
	fprintf(outfile,"gsave /Symbol 12 selectfont (S1) ctrshow grestore\n");
      else
	fprintf(outfile,"(%s) ctrshow\n",label);
      fprintf(outfile,
	      "0.5 setgray %.1f 0 moveto 0 %d rlineto stroke 0 setgray\n",
	      x,pt_h);
    }
  }

  if (e->e_fermi){
    fprintf(outfile,"%% E Fermi line\n"
	    "[2 5] 0 setdash\n");
    fprintf(outfile,"0 %.1f moveto %d 0 rlineto stroke\n",
	    pt_h*(*e->e_fermi*scale-ymin)/(ymax-ymin),pt_w);
    fprintf(outfile,"%d %.1f moveto (E) show gsave 0.6 dup scale "
	    "0 -6 rmoveto (F) show grestore\n",pt_w+3,
	    pt_h*(*e->e_fermi*scale-ymin)/(ymax-ymin)-5);
    fprintf(outfile,"[] 0 setdash\n");
  }

  /* One stroked path per band per spin */
  for(ns=0;ns<e->nspins;ns++){
    if ((e->nspins>1)&&(mono))
      fprintf(outfile,"LCs %d get aload pop setrgbcolor\n",ns&7);
    for(nb=0;nb<e->path_nbands;nb++){
      ibreak=-1;
      if (k->break_n) ibreak=0;
      if (!mono)
	fprintf(outfile,"LCs %d get aload pop setrgbcolor\n",nb&7);
      fprintf(outfile,"0 %f moveto\n",
	      pt_h*(e->path_eval[ns*e->path_nbands+nb]*scale-ymin)/(ymax-ymin));
      for(ik=1;ik<k->path_nkpt;ik++){
	fprintf(outfile,"%f %f ",pt_w*k_dist[ik]/k_dist[k->path_nkpt-1],
		pt_h*(e->path_eval[ik*e->nspins*e->path_nbands+
				   ns*e->path_nbands+nb]*scale-ymin)/
		(ymax-ymin));
	/* moveto across the next expected break, lineto otherwise */
	if (ibreak>-1){
	  if((aeq(k->path_kpts[ik-1].frac[0],k->breaks[6*ibreak]))&&
	     (aeq(k->path_kpts[ik-1].frac[1],k->breaks[6*ibreak+1]))&&
	     (aeq(k->path_kpts[ik-1].frac[2],k->breaks[6*ibreak+2]))&&
	     (aeq(k->path_kpts[ik].frac[0],k->breaks[6*ibreak+3]))&&
	     (aeq(k->path_kpts[ik].frac[1],k->breaks[6*ibreak+4]))&&
	     (aeq(k->path_kpts[ik].frac[2],k->breaks[6*ibreak+5]))){
	    ibreak++;
	    if (ibreak>=k->break_n) ibreak=-1;
	    fprintf(outfile,"moveto\n");
	  }
	  else fprintf(outfile,"lineto\n");
	}
	else fprintf(outfile,"lineto\n");
      }
      fprintf(outfile,"stroke\n");
    }
  }

  if (dos){
    fprintf(outfile,"\n0 setgray %d %d translate\n\n",pt_w+50,-50);
    /* Pass the text after "_dos" on to the DOS plotter; buffer holds at
     * most 11+30+1 characters plus the terminator */
    cptr=strstr(fmt,"_dos")+4;
    if (ph) strcpy(buffer,"--phdos_eps");
    else strcpy(buffer,"--dos_eps");
    if (*cptr)
      strncat(buffer,cptr,30);
    strcat(buffer,"r");
    range[0]=ymin;
    range[1]=ymax;
    dos_plot(outfile,c,m,k,NULL,e,buffer,range,NULL);
  }

}

int sp2k(char *label, struct special_sc *sp, double k[3]){
  int i;
  struct pt *p;

  if ((!strcmp(label,"G"))||(!strcasecmp(label,"gamma"))||
      (!strcasecmp(label,"\\gamma"))){
    k[0]=k[1]=k[2]=0;
    fprintf(stderr,"Found gamma\n");
    return 0;
  }

  if (!sp) return 1;
  p=sp->pts;
  
  while(*p->l){
    if (!strcmp(label,p->l)){
      for(i=0;i<3;i++) k[i]=p->k[i];
      fprintf(stderr,"Found non gamma\n");
      return 0;
    }
    p++;
  }

  return 1;

}
