import sys.path
from java.lang.System import getProperty
sys.path.append(getProperty('fiji.dir') + '/macros/jay_python')
from ij.text import TextWindow,TextPanel
from ij import IJ,WindowManager
import math
import table_tools as tt

#analysis parameters; defined first so the tracker call below uses the same z ratio
tabname='Traj Data' #name of our tracking table
zratio=3.125 #z spacing/xy spacing
distthresh=20.0 #distance threshold for a point pair
zdistthresh=10.0 #z distance threshold for a point pair

#run the 3D max-finding tracker on channel 2 of the active image
img=WindowManager.getCurrentImage()
if img is None:
	raise RuntimeError('no image open: select the stack to analyze before running this script')
orig=img.getTitle()
print(orig)
IJ.runMacro("Stack.setChannel(2);")
#build the option string from zratio so the tracker and the pair analysis agree
trackopts=("threshhold=Max z_edge_buffer=1 z_ratio="+str(zratio)+
	" min_separation=17 thresh_fraction=0.2 xy_edge_buffer=10 max_blobs=1000"+
	" display_frame=1 display_slice=19 link_range=10.00000 max_link_delay=1"+
	" min_traj_length=0 min_separation_z=6")
IJ.run("track max not mask 3D jru v1",trackopts)
#close the 'Trajectories' image window if the tracker left one open
trajwin=WindowManager.getImage('Trajectories')
if trajwin is not None:
	trajwin.close()

#read the tracking table: sdata holds the raw row strings, headers the column names
sdata,headers=tt.getTableData(tabname)
if len(sdata)==0:
	raise RuntimeError("table '"+tabname+"' has no rows: no spots were tracked")
rows=len(sdata)
cols=len(sdata[0])
#transposed so fdata[c] is the list of values in column c (presumably numeric; see table_tools)
fdata=tt.transpose(tt.getTableValues(sdata))
print(headers)
xc=headers.index('x')
yc=headers.index('y')
zc=headers.index('z')
print('xyz columns')
print([xc,yc,zc])
#now find the nearest and second-nearest neighbor of each spot
def closestDist(idx,xarr,yarr,zarr,zratio):
	"""Find the nearest and second-nearest neighbors of spot idx.

	Distances are Euclidean in xy pixel units, with z offsets multiplied by
	zratio (z spacing / xy spacing) to account for anisotropic sampling.

	Args:
		idx: index of the query spot in the coordinate sequences.
		xarr, yarr, zarr: equal-length sequences of spot coordinates.
		zratio: z spacing divided by xy spacing.

	Returns:
		(dist1, idx1, dist2, idx2, zdist1): distance and index of the nearest
		other spot, distance and index of the second-nearest other spot, and
		the zratio-scaled z separation to the nearest spot. A neighbor that
		does not exist (fewer than one/two other spots) is reported with
		distance float('inf') and index None. Ties go to the lower index.
	"""
	import heapq
	zr2=zratio*zratio
	x1=xarr[idx]
	y1=yarr[idx]
	z1=zarr[idx]

	def sqdist(i):
		#squared, z-scaled distance from the query spot to spot i
		dx=x1-xarr[i]
		dy=y1-yarr[i]
		dz=z1-zarr[i]
		return dx*dx+dy*dy+zr2*dz*dz

	#exclude the query spot explicitly instead of giving it a finite sentinel
	#distance: a sentinel such as 1000.0 (squared, ~31.6 px) can be smaller than
	#a real neighbor's distance, and the spot would then report itself
	others=[i for i in range(len(xarr)) if i!=idx]
	#nsmallest is stable (equivalent to sorted(...)[:2]), so ties keep the lower index
	closest=heapq.nsmallest(2,others,key=sqdist)
	inf=float('inf')

	if len(closest)>0:
		minidx=closest[0]
		mindist=math.sqrt(sqdist(minidx))
		zdist=zratio*abs(z1-zarr[minidx])
	else:
		minidx,mindist,zdist=None,inf,inf
	if len(closest)>1:
		minidx2=closest[1]
		mindist2=math.sqrt(sqdist(minidx2))
	else:
		minidx2,mindist2=None,inf
	return mindist,minidx,mindist2,minidx2,zdist

#do that operation for all of the spots
#each entry: (nearest dist, nearest idx, 2nd nearest dist, 2nd nearest idx, nearest z dist)
nspots=len(fdata[xc])
mindists=[closestDist(i,fdata[xc],fdata[yc],fdata[zc],zratio) for i in range(nspots)]
#start by finding the single spots (no neighbor within distthresh)
#include values where zdist is > zdistthresh (pair oriented incorrectly)
singles=[(md[0]>distthresh or md[4]>zdistthresh) for md in mindists]

#now find the triple centers (two neighbors within distthresh)
triplecs=[(md[0]<=distthresh and md[2]<=distthresh) for md in mindists]
#now mark the two neighbors of every triple center as well
triplens=[False]*nspots
for md,iscenter in zip(mindists,triplecs):
	if iscenter:
		triplens[md[1]]=True
		triplens[md[3]]=True
#everything else should be doubles, make a big drop list
droplist=[(s or tc or tn) for s,tc,tn in zip(singles,triplecs,triplens)]
print('n singles: '+str(sum(singles)))
print('n triple centers: '+str(sum(triplecs)))
print('n drops: '+str(sum(droplist)))
#drop the lines and make a new table, writing each pair's rows after one another
#a kept spot's nearest neighbor can still be a dropped triple center whose own
#two nearest neighbors are other spots, so only emit mutual nearest-neighbor
#pairs in which neither spot was dropped (this also prevents duplicate rows)
newsdata=[]
picked=[False]*nspots
for i in range(nspots):
	if droplist[i] or picked[i]:
		continue
	neighbor=mindists[i][1]
	if droplist[neighbor] or picked[neighbor] or mindists[neighbor][1]!=i:
		continue
	newsdata.append(sdata[i])
	newsdata.append(sdata[neighbor])
	picked[i]=True
	picked[neighbor]=True
print('n doubles: '+str(len(newsdata)//2))

tt.makeTableFromStrings(newsdata,headers,"filtered_doubles")
IJ.run("plot columns jru v1", "windows=filtered_doubles x_vals z_vals separate x_column=x y_column=y z_column=z")
#IJ.run("set 3D shapes jru v1", "shape=square color=black")
#square brackets let ImageJ parse image titles that contain spaces
IJ.run("traj 2 roi jru v1", "image=["+orig+"] trajectory=[filtered_doubles Plot] roi=Roi_Manager square=10 z_ratio=1")
plotwin=WindowManager.getImage('filtered_doubles Plot')
if plotwin is not None:
	plotwin.close()