#include "Detector.h"

#include <iostream>
#include <cmath>
using namespace std;

// Default-constructs a detector with no templates and a pyramid step of 1.2:
// every level scanned by detect() uses an image shrunk by a further factor of
// scaleFactor_ relative to the previous one.
Detector::Detector() : scaleFactor_( 1.2 ) {}


bool Detection::overlaps( const Detection& detection ) const {
	int dx = abs(x-detection.x);
	int dy = abs(y-detection.y);
	int radius = (size+detection.size)/2;
	return dx <= radius && dy <= radius;
}

void Detector::detect( const Image& input ) {
	detected_.clear();
	Image temp;
	temp.copy( input );
	detect( input, temp, 1 );
}

// Registers `ratioTemplate` as a member of template group `group`.
// Any missing groups up to and including `group` are created empty. In
// detect(), a window is accepted when every template of at least one group
// matches it: templates within a group are AND'ed, and groups are OR'ed.
//
// Only the address of `ratioTemplate` is stored, not a copy. The caller must
// keep the template alive for as long as this detector may use it.
void Detector::addTemplate( const RatioTemplate& ratioTemplate, unsigned int group ) {
	if ( templates_.size() <= group )
		templates_.resize( group + 1 );  // new groups are value-initialised (empty)
	templates_[group].push_back( &ratioTemplate );
}

// Forgets every registered template group.
// Only the stored pointers are dropped. The RatioTemplate objects are owned by
// the caller of addTemplate() and are not destroyed here.
void Detector::clearTemplates() {
	templates_.erase( templates_.begin(), templates_.end() );
}

void Detector::detect( const Image& originalImage, Image& input, double scale ) {
	const int width = input.width();
	const int height = input.height();
	const int size = 20;
	const unsigned int numGroups = templates_.size();
	CvMat mat;
	CvRect rect;
	CvArr* image = input.image();
	rect.width = rect.height = size;
	for ( int i = 0; i < width-(size-1); i++ ) {
		rect.x=i;
		for ( int j = 0; j < height-(size-1); j++ ) {
			rect.y=j;
			cvGetSubRect( image, &mat, rect );
			//Image subimage = input.subimage( i, j, size, size );
			bool matched = false;
			for ( unsigned int k = 0; k < numGroups; k++ ) {
				const stdvector<const RatioTemplate*>& group = templates_[k];
				const unsigned int numTemplates = group.size();
				// all of the templates in a group must match
				for ( unsigned int l = 0; l < numTemplates; l++ ) {
					//if ( !(matched = group[l]->matches( subimage )) )
					//	break;
					if ( !(matched = group[l]->matches( &mat, size, size )) )
						break;
				}
				// groups are OR'ed together
				if ( matched )
					break;
			}
			
			
			if ( matched ) {
				Detection detection( (int)((i+size/2)*scale), (int)((j+size/2)*scale), (int)(size*scale), 1 );
				detected_.push_back( detection );
			}
		}
	}
	 
	scale  = scaleFactor_*scale;
	int newWidth  = (int)(originalImage.width()/scale);
	int newHeight = (int)(originalImage.height()/scale);
	if ( newWidth >= size && newHeight >= size ) {
		input.copy( originalImage );
		//input.blur();
		input.resize(newWidth, newHeight);
		//input.scalePyramidDown();
		detect( originalImage, input, scale );
	}
}

void Detector::coalesceDetections( int width, int height, int neighbourhood, unsigned int threshold ) {
	stdvector<Detection> coalesced;
	
	int radius = 10;
	for ( int i = radius; i < width-radius; i++ ) {
		for ( int j = radius; j < height-radius; j++ ) {
			stdvector<Detection> neighbours;
			for ( unsigned int k = 0; k < detected_.size(); k++ ) {
				Detection detected = detected_[ k ];
				int dx = detected.x-i;
				int dy = detected.y-j;
				if ( dx*dx + dy*dy < neighbourhood*neighbourhood ) {
					neighbours.push_back( detected );
				}
			}
			if ( neighbours.size() >= threshold ) {
				int momentx = 0;
				int momenty = 0;
				int moment  = 0;
				for ( unsigned int k = 0; k < neighbours.size(); k++ ) {
					Detection detected = neighbours[ k ];
					moment += detected.size;
					momentx += (detected.size*detected.x);
					momenty += (detected.size*detected.y);
				}
				Detection avg( momentx/moment, momenty/moment, moment/neighbours.size(), neighbours.size() );
				coalesced.push_back( avg );
			}
		}
	}
	
	
	detected_.clear();
	
	// now deal with overlapping
	for ( unsigned int i = 0; i < coalesced.size(); i++ ) {
		Detection di = coalesced[ i ];
		bool add = true;
		for ( unsigned int j = 0; j < coalesced.size(); j++ ) {
			if ( i == j ) continue;
			Detection dj = coalesced[ j ];
			if ( di.overlaps( dj ) && (di.count < dj.count || (di.count == dj.count && i > j)) ) {
				add = false;
				break;
			}
		}
		if ( add )
			detected_.push_back( di );
	}
	
}

void Detector::groupOverlapping() {
	bool overlapping = false;
	
	do {
		overlapping = false;
		//stdvector<Detection> coalesced;
		for ( stdvector<Detection>::iterator i = detected_.begin(); i != detected_.end(); i++ ) {
			Detection& di = *i;
			stdvector<Detection>::iterator closest;
			int minDistSq = INT_MAX;
			for ( stdvector<Detection>::iterator j = i; j != detected_.end(); j++ ) {
				Detection& dj = *j;
				if ( i != j ) {
					int dx = di.x-dj.x, dy = di.y-dj.y;
					int distSq = dx*dx + dy*dy;
					int radius=(di.size+dj.size)/2;
					if ( distSq < (radius*radius) ) {
						overlapping=true;
						if ( distSq < minDistSq ) {
							closest = j;
							minDistSq = distSq;
						}
					}
				}
			}
			if ( overlapping ) {
				// biased towards detections that have already
				// been grouped
				Detection& dj = *closest;
				int count = di.count+dj.count;
				int size = (di.count*di.size+dj.count*dj.size)/count;
				int x = (di.count*di.x+dj.count*dj.x)/count;
				int y = (di.count*di.y+dj.count*dj.y)/count;
				Detection d( x, y, size, count );
				detected_.erase(closest);
				detected_.erase(i);
				detected_.push_back(d);
				break;
			}
			
		}
		//detected_ = coalesced;
		/*detected_.clear();
		for ( unsigned int i = 0; i < coalesced.size(); i++ ) {
			Detection& d = coalesced[i];
			detected_.push_back(d);
		}*/
	}
	while( overlapping );
	
	
}

// Outlines every current detection on `image`.
// Each detection gets two nested rectangles. One is drawn with value 255.0
// along the detection's square (side `size`, centred on (x, y)). The other is
// drawn with value 0.0 one pixel further out, presumably so the outline stands
// out on both light and dark backgrounds.
void Detector::drawDetections( Image& image ) const {
	for ( stdvector<Detection>::const_iterator it = detected_.begin(); it != detected_.end(); ++it ) {
		const int half   = it->size / 2;
		const int left   = it->x - half;
		const int top    = it->y - half;
		const int right  = it->x + half;
		const int bottom = it->y + half;
		image.rect( 255.0, left, top, right, bottom );
		image.rect( 0.0, left - 1, top - 1, right + 1, bottom + 1 );
	}
}
