BaatzAlgorithm.txx 9.27 KB
/*=========================================================================

  Program:   Large Scale Segmentation (LSS)
  Language:  C++
  author:    Lassalle Pierre



  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef __BaatzAlgorithm_txx
#define __BaatzAlgorithm_txx

// Record the image dimensions in the region-merging handler's encoder.
// cols is passed to SetCols and rows to SetRows.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::SetDimensionForEncoder(unsigned int cols, unsigned int rows)
{
	auto& encoder = m_RMHandler.m_Encoder;
	encoder.SetCols(cols);
	encoder.SetRows(rows);
}

// Read the image at input_file and build one initial region per pixel.
//
// Every region starts with area 1 and perimeter 4. Per band it stores:
//   - the pixel value (m_Avg_Color and m_Color_Sum),
//   - its square (m_Avg_Color_Square),
//   - a standard deviation of 0 (m_Std_Color).
// The encoder is given the image dimensions, regions are appended to
// m_RegionList in iterator order, and neighbourhoods are then initialised by
// the handler.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::InitFromTiffImage(const std::string& input_file)
{
	auto reader = ReaderType::New();
	reader->SetFileName(input_file);
	reader->Update();

	auto image = reader->GetOutput();
	auto regionToProcess = image->GetLargestPossibleRegion();
	const unsigned int cols = regionToProcess.GetSize()[0];
	const unsigned int rows = regionToProcess.GetSize()[1];
	const unsigned int bands = image->GetNumberOfComponentsPerPixel();

	// Single entry point for the encoder dimensions (resolves the former TODO
	// "use only one method (SetRegionToProcess)").
	SetDimensionForEncoder(cols, rows);

	// Widen before multiplying: cols*rows in unsigned int can overflow on very large images.
	m_RegionList.reserve(static_cast<std::size_t>(cols) * rows);

	IteratorType it(image, regionToProcess);

	// Create the vertices; idx is the linear position of the pixel in iteration order.
	long unsigned int idx = 0;
	for(it.GoToBegin(); !it.IsAtEnd(); ++it, ++idx)
	{
		// Fetch the pixel once instead of once per band and per attribute.
		const auto pixel = it.Get();

		auto r = std::make_shared<BaatzRegion>();
		m_RMHandler.InitGenericRegion(r, idx);
		r->m_Area = 1;
		r->m_Perimeter = 4;
		r->m_Avg_Color.reserve(bands);
		r->m_Avg_Color_Square.reserve(bands);
		r->m_Color_Sum.reserve(bands);
		r->m_Std_Color.reserve(bands);

		for(unsigned int b = 0; b < bands; b++)
		{
			const double value = pixel[b];
			r->m_Avg_Color.push_back(value);
			r->m_Avg_Color_Square.push_back(value * value); // square without pow()
			r->m_Color_Sum.push_back(value);
			r->m_Std_Color.push_back(0);
		}

		m_RegionList.push_back(r);
	}
	m_RMHandler.InitNeighborhood(m_RegionList);
}

// Spectral heterogeneity increase caused by merging r1 and r2.
//
// For each band, the standard deviation of the union is derived from the
// stored running sums. The result, summed over bands, is:
//   a_sum*std_merged - (a1*std1 + a2*std2)
//
// Arithmetic is done in double. The variance is clamped at 0, because
// catastrophic cancellation in sum(x^2) - 2*mean*sum + n*mean^2 can yield a
// tiny negative value, and sqrt of that would be NaN. Heap-free: no
// variable-length arrays (non-standard C++).
template<class TInputImage>
double
BaatzAlgorithm<TInputImage>::ColorComponentCostFusion(RegionPointerType r1, RegionPointerType r2)
{
	const unsigned int bands = r1->m_Avg_Color.size();
	const double a_current = r1->m_Area;
	const double a_neighbor = r2->m_Area;
	const double a_sum = a_current + a_neighbor;

	double color_h = 0.0;
	for (unsigned int b = 0; b < bands; b++)
	{
		// Area-weighted mean of the union.
		const double mean = (r1->m_Avg_Color[b]*a_current + r2->m_Avg_Color[b]*a_neighbor) / a_sum;
		const double squarePixels = static_cast<double>(r1->m_Avg_Color_Square[b]) + r2->m_Avg_Color_Square[b];
		const double colorSum = static_cast<double>(r1->m_Color_Sum[b]) + r2->m_Color_Sum[b];

		double variance = (squarePixels - 2*mean*colorSum + a_sum*mean*mean) / a_sum;
		if (variance < 0.0)
			variance = 0.0;
		const double stddev = sqrt(variance);

		const double weightedStd = a_current*r1->m_Std_Color[b] + a_neighbor*r2->m_Std_Color[b];
		color_h += a_sum*stddev - weightedStd;
	}
	return color_h;
}

// Spatial heterogeneity increase caused by merging r1 and r2.
//
// It blends two terms, each computed for the union minus the sum over the
// two parts:
//   - compactness: area*perimeter/sqrt(area);
//   - smoothness: area*perimeter/bbox_perimeter.
// The blend is m_CompactnessWeight*compactness + (1 - m_CompactnessWeight)*smoothness.
template<class TInputImage>
double
BaatzAlgorithm<TInputImage>::CompactnessComponentCostFusion(RegionPointerType r1, RegionPointerType r2)
{
	// Areas of the two segments and of the merged segment.
	const float areaFirst = r1->m_Area;
	const float areaSecond = r2->m_Area;
	const float areaMerged = areaFirst + areaSecond;

	// Perimeters: the merged one loses the shared boundary on both sides.
	const float perimFirst = r1->m_Perimeter;
	const float perimSecond = r2->m_Perimeter;
	const float perimMerged = r1->m_Perimeter + r2->m_Perimeter - 2*m_RMHandler.GetConnections(r1, r2);

	// Bounding-box perimeters (2*width + 2*height).
	auto mergedBox = m_RMHandler.GetResultingBbox(r1, r2);
	const float boxFirst = (r1->m_Bbox.GetSize(0))*2 + (r1->m_Bbox.GetSize(1))*2;
	const float boxSecond = (r2->m_Bbox.GetSize(0))*2 + (r2->m_Bbox.GetSize(1))*2;
	const float boxMerged = (mergedBox.GetSize(0))*2 + (mergedBox.GetSize(1))*2;

	const double smoothness = areaMerged*perimMerged/boxMerged
		- (areaSecond*perimSecond/boxSecond + areaFirst*perimFirst/boxFirst);

	const double compactness = areaMerged*perimMerged/sqrt(areaMerged)
		- (areaSecond*perimSecond/sqrt(areaSecond) + areaFirst*perimFirst/sqrt(areaFirst));

	const auto wCompact = m_Parameters.m_CompactnessWeight;
	return wCompact*compactness + (1-wCompact)*smoothness;
}

// Run iterative local-mutual-best-fitting region merging.
//
// The loop stops after max_iter passes, or after a pass in which no merge
// happened. max_iter defaults to 500 unless m_NumberOfIterations is positive.
//
// The merge cost is:
//   cw*spectral + (1-cw)*spatial,
// where the spatial term is only evaluated while the weighted spectral part
// stays below the scale parameter.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::Segmentation()
{
	m_PreviousMerge = false;
	unsigned int max_iter = 500;
	if(m_NumberOfIterations > 0)
		max_iter = m_NumberOfIterations;

	// The weights are copied INTO the closure. m_ComputeCostFunction is a
	// member that outlives this call, so capturing these locals by reference
	// (the former [&]) would leave dangling references after return.
	const auto cw = m_Parameters.m_ColorWeight;
	const auto sc = m_Parameters.m_Scale;
	m_ComputeCostFunction = [this, cw, sc](RegionPointerType r1, RegionPointerType r2)->double
	{
		double cost = cw * ColorComponentCostFusion(r1, r2);

		// Spectral part alone already exceeds the scale: skip the spatial term.
		if(cost < sc)
			cost += (1-cw) * CompactnessComponentCostFusion(r1, r2);
		return cost;
	};

	unsigned int curr_step = 0;
	bool prev_merged = true;
	while(curr_step < max_iter && prev_merged)
	{
		prev_merged = false;
		std::cout << curr_step << std::endl;
		m_RMHandler.UpdateMergingCost(m_RegionList, m_ComputeCostFunction);

		for(auto& r: m_RegionList)
		{
			if(m_RMHandler.IsLMBF(r, m_Parameters.m_Scale))
			{
				auto res = m_RMHandler.GetRegionsToMerge(r);
				// User contribution: fold res.second's attributes into res.first
				// before the handler updates the graph.
				UpdateAttribute(res.first, res.second);
				m_RMHandler.Update(res);
				prev_merged = true;
			}
		}
		m_RMHandler.RemoveExpiredVertices(m_RegionList);
		curr_step++;
		m_PreviousMerge = prev_merged;
	}
}

// Fold the attributes of r2 into r1 (r1 becomes the merged region).
//
// Per band it updates:
//   - the area-weighted mean,
//   - the summed squares,
//   - the summed values,
//   - the standard deviation of the union.
// Then area and perimeter are updated; the perimeter loses twice the number
// of connections reported by the handler.
//
// Computations are in double, and the variance is clamped at 0 to avoid a
// NaN from sqrt of a small negative rounding residue. Heap-free: no
// variable-length arrays.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::UpdateAttribute(RegionPointerType r1, RegionPointerType r2)
{
	// Update the spectral attributes
	const unsigned int bands = r1->m_Avg_Color.size();
	const double a_current = r1->m_Area;
	const double a_neighbor = r2->m_Area;
	const double a_sum = a_current + a_neighbor;

	for(unsigned int b = 0; b < bands; b++)
	{
		// Each band only reads and writes index b, so one pass is enough.
		const double mean = (r1->m_Avg_Color[b]*a_current + r2->m_Avg_Color[b]*a_neighbor) / a_sum;
		const double squarePixels = static_cast<double>(r1->m_Avg_Color_Square[b]) + r2->m_Avg_Color_Square[b];
		const double colorSum = static_cast<double>(r1->m_Color_Sum[b]) + r2->m_Color_Sum[b];

		double variance = (squarePixels - 2*mean*colorSum + a_sum*mean*mean) / a_sum;
		if(variance < 0.0)
			variance = 0.0;

		r1->m_Avg_Color[b] = mean;
		r1->m_Avg_Color_Square[b] = squarePixels;
		r1->m_Color_Sum[b] = colorSum;
		r1->m_Std_Color[b] = sqrt(variance);
	}

	// Update spatial attributes
	r1->m_Area += r2->m_Area;
	r1->m_Perimeter += r2->m_Perimeter - 2*m_RMHandler.GetConnections(r1,r2);
}

// Write a label image to ofname: every pixel of the i-th region in
// m_RegionList (1-based) receives label i.
//
// The image size comes from the encoder's cols/rows. Pixels that no region
// claims keep the value 0.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::WriteLabelImage(const std::string& ofname)
{
	typename InputImageType::IndexType index;
	typename InputImageType::SizeType size;
	typename InputImageType::RegionType region;
	auto label_image = LabelImageType::New();

	index[0] = 0;
	index[1] = 0;
	size[0] = m_RMHandler.m_Encoder.GetCols();
	size[1] = m_RMHandler.m_Encoder.GetRows();
	region.SetIndex(index);
	region.SetSize(size);

	label_image->SetRegions(region);
	label_image->Allocate();
	// Allocate() leaves the buffer uninitialised; give uncovered pixels a defined value.
	label_image->FillBuffer(0);

	unsigned int label = 1;
	for(auto& v : m_RegionList)
	{
		auto pixels = m_RMHandler.m_Encoder.GenerateAllPixels(v->m_Id, v->m_Contour, v->m_Bbox);
		for(auto& pix: pixels)
		{
			// pix is a linear index: row-major, with width size[0].
			index[0] = pix % size[0];
			index[1] = pix / size[0];
			label_image->SetPixel(index, label);
		}
		label++;
	}

	auto label_writer = LabelWriterType::New();
	label_writer->SetFileName(ofname);
	label_writer->SetInput(label_image);
	label_writer->Update();
}

// Write a 3-component image to ofname in which the pixels returned by the
// encoder's GeneratePixels for each region are painted with that region's
// mean colour.
//
// Every pixel is first zeroed. At most the first 3 bands of m_Avg_Color are
// copied, so regions built from images with fewer than 3 bands no longer
// read past the end of the vector; missing components stay 0.
// NOTE(review): this uses GeneratePixels, whereas WriteLabelImage uses
// GenerateAllPixels. Confirm whether every interior pixel is meant to be painted.
template<class TInputImage>
void
BaatzAlgorithm<TInputImage>::WriteResultingImage(const std::string& ofname)
{
	const unsigned int nb_out_bands = 3;

	auto out_image = InputImageType::New();
	typename InputImageType::IndexType index;
	typename InputImageType::SizeType size;
	typename InputImageType::RegionType region;

	index[0] = 0;
	index[1] = 0;
	size[0] = m_RMHandler.m_Encoder.GetCols();
	size[1] = m_RMHandler.m_Encoder.GetRows();
	region.SetIndex(index);
	region.SetSize(size);
	out_image->SetRegions(region);
	out_image->SetNumberOfComponentsPerPixel(nb_out_bands);
	out_image->Allocate();

	// Zero the whole image so pixels not painted below have a defined value.
	for(unsigned int r = 0; r < size[1]; r++)
	{
		for(unsigned int c = 0; c < size[0]; c++)
		{
			index[0] = c;
			index[1] = r;
			auto pixel_value = out_image->GetPixel(index);
			for(unsigned int b = 0; b < nb_out_bands; b++)
				pixel_value[b] = 0;
			out_image->SetPixel(index, pixel_value);
		}
	}

	for(auto& v : m_RegionList)
	{
		// Never read beyond the bands the region actually has.
		const unsigned int region_bands = static_cast<unsigned int>(v->m_Avg_Color.size());
		const unsigned int nb_copied = region_bands < nb_out_bands ? region_bands : nb_out_bands;

		auto pixels = m_RMHandler.m_Encoder.GeneratePixels(v->m_Id, v->m_Contour);
		for(auto& pix: pixels)
		{
			// pix is a linear index: row-major, with width size[0].
			index[0] = pix % size[0];
			index[1] = pix / size[0];
			auto pixel_value = out_image->GetPixel(index);
			for(unsigned int b = 0; b < nb_copied; b++)
				pixel_value[b] = v->m_Avg_Color[b];
			out_image->SetPixel(index, pixel_value);
		}
	}

	auto writer = WriterType::New();
	writer->SetFileName(ofname);
	writer->SetInput(out_image);
	writer->Update();
}

#endif