#ifndef _itkKullbackLeiblerNMF_txx #define _itkKullbackLeiblerNMF_txx #include #include "itkKullbackLeiblerNMF.h" namespace itk { //------------------------------------------------------------------------ //KullbackLeiblerNMF //------------------------------------------------------------------------ template KullbackLeiblerNMF ::KullbackLeiblerNMF():m_Eps(2.2204e-16), m_ConvergenceTestStepSize(10), m_Iter(5000), m_Stop(5), m_Verbose(false) { } //------------------------------------------------------------------------ //~KullbackLeiblerNMF //------------------------------------------------------------------------ template KullbackLeiblerNMF ::~KullbackLeiblerNMF() throw() { } //-------------------------------------------------------------------------- // NMF Update //------------------------------------------------------------------------- template void KullbackLeiblerNMF::NMFUpdate() { vnl_vector x1(this->GetNumberOfClasses(),0); vnl_vector x2(this->GetNumberOfClasses(),0); InMatrixType *v = this->GetInput(); OutMatrixType *w = this->GetWMatrix(); OutMatrixType *h = this->GetHMatrix(); //First perform the H matrix update for(unsigned int i=0; iGetNumberOfClasses(); i++) { x1[i] = w->get_column(i).sum(); } vnl_matrix estV = (*w) * (*h); //We are overwriting the memory used by estV to compute (origV/estV) for(unsigned int i=0; iinplace_transpose(); vnl_matrix estH = ((*w) * estV); for(unsigned int i=0; iSetHMatrix( estH ); //Now perform the w matrix update for(unsigned int i=0; iGetNumberOfClasses(); i++) { x2[i] = estH.get_row(i).sum(); } //untransposing m_W w->inplace_transpose(); estV = (*w) * estH; //We are overwriting the memory used by estV to compute (origV/estV) for(unsigned int i=0; i estW = (estV * estH.transpose()); estH.inplace_transpose(); vnl_matrix estW = (estV * estH); for(unsigned int i=0; iSetWMatrix( estW ); for(typename OutMatrixType::iterator it = w->begin(); it < w->end(); it++) { if( *it < m_Eps ) *it = m_Eps; } for(typename OutMatrixType::iterator it = h->begin(); it < h->end(); it++) { if( *it < m_Eps ) *it = m_Eps; } }//end NMF Update //-------------------------------------------------------------------------- // Compute //------------------------------------------------------------------------- template void KullbackLeiblerNMF::Compute() { //Do bounds check for the input matrix validity InMatrixType *v = this->GetInput(); double minVal = v->min_value(); if( minVal < 0) { itk::ExceptionObject exception(__FILE__, __LINE__); exception.SetDescription("Matrix elements cannot be negative."); throw exception; } for(unsigned int i=0; irows(); i++) { if( v->get_row(i).sum() <= 0 ) { std::cout << "All elements 0 in row " << i << std::endl; std::cout << v->get(i,0) << "\t" << v->get(i, 10) << std::endl; itk::ExceptionObject exception(__FILE__, __LINE__); exception.SetDescription("Not all entries in a row can be zero."); throw exception; break; } } //Check initialization OutMatrixType *w = this->GetWMatrix(); OutMatrixType *h = this->GetHMatrix(); if( (w->columns() != h->rows()) || (w->rows() != v->rows()) || (h->columns() != v->columns()) ) { itk::ExceptionObject exception(__FILE__, __LINE__); exception.SetDescription("Please check initialization."); throw exception; } //create the containers to test convergence m_Label.set_size(h->columns()); m_Label.fill(-1); m_LabelOld.set_size(h->columns()); m_LabelOld.fill(-1); //Perform the decomposition unsigned int inc =0; for(unsigned int idx=0; idxNMFUpdate(); if( (idx % m_ConvergenceTestStepSize) == 0) { unsigned int count = this->TestConvergence(); if(m_Verbose) { std::cout << "Iteration no : " << idx << "\t" << " Num. Mismatch : " << count; std::cout << "\t" << "No change count : " << inc << std::endl; } //test termination of iterations if(count == 0) { inc++; } else { inc =0; } if( inc > m_Stop ) { break; } } } }//end Compute //-------------------------------------------------------------------------- // TestConvergence //------------------------------------------------------------------------- template unsigned int KullbackLeiblerNMF::TestConvergence() { OutMatrixType *h = this->GetHMatrix(); m_Label.set_size(h->columns()); m_Label.fill(0); //Find labels in the H matrix based on a maximum liklihood method for( unsigned int i=0; i< h->columns(); i++) { unsigned int maxidx =0; double maxval= h->get(0,i); for( unsigned int j=1; j< h->rows(); j++) { if (h->get(j,i) > maxval) { maxval = h->get(j,i); maxidx = j; } } if( static_cast(m_Label(i)) != maxidx ) { m_Label(i) = maxidx; } } int count=0; //Check how many labels changed between iterations for(unsigned int i=0; i typename KullbackLeiblerNMF::MLELabelVectorType* KullbackLeiblerNMF::GetMLELabel(OutMatrixType *m, unsigned int e) { unsigned int vecSize; switch(e) { case 1: vecSize = m->rows(); break; case 2: vecSize = m->columns(); break; default: std::cout << "No defaults provided; use either 1 or 0." << std::endl; itk::ExceptionObject exception(__FILE__, __LINE__); exception.SetDescription("No defaults provided; use either 1 or 2."); throw exception; } m_MLELabel.set_size(vecSize); m_MLELabel.fill(0); switch(e) { //Process the data in a row major sense case 1: for( unsigned int i=0; i< m->rows(); i++) { unsigned int maxidx =0; double maxval= m->get(i,0); for( unsigned int j=1; j< m->columns(); j++) { if (m->get(i,j) > maxval) { maxval = m->get(i,j); maxidx = j; } } if( m_MLELabel(i) != maxidx ) { m_MLELabel(i) = maxidx; } } break; //end case 1 //Process the data column major wise case 2: for( unsigned int i=0; i< m->columns(); i++) { unsigned int maxidx =0; double maxval= m->get(0,i); for( unsigned int j=1; j< m->rows(); j++) { if (m->get(j,i) > maxval) { maxval = m->get(j,i); maxidx = j; } } if( m_MLELabel(i) != maxidx ) { m_MLELabel(i) = maxidx; } } break; //end case 1 }//end switch m_MLELabel += 1; return &m_MLELabel; } //------------------------------------------------------------------------ // WriteVector Output //------------------------------------------------------------------------ template int KullbackLeiblerNMF::WriteMLE(vnl_vector *m, std::string &outFile) { try { std::ofstream output(outFile.c_str(), std::ios::out | std::ios::binary); output.write(reinterpret_cast (&m[0]), m->size() *sizeof(unsigned int)); if (output.fail()) throw std::runtime_error("Error writing the matrix"); } catch (const std::exception &e) { std::cout << e.what() << std::endl; exit(-1); } catch (...) { std::cout << "Module failed for unknown reason. Investigation needed." << std::endl; return -1; } return 0; }//end WriteVector } // namespace itk #endif