#ifndef _itkNMFBase_h #define _itkNMFBase_h #include #include #include #include #include #include #include #include #include "vnl/vnl_matrix.h" #include "vnl/vnl_sample.h" #include "vnl/vnl_vector.h" #include "itkExceptionObject.h" namespace itk { /** \class NMF Base * \brief Base class for non-negative matrix factorization (NMF) * based dimentionality reduction/clustering object * * itkNMFbase provides an abstract base class to implement non-negative matrix * factorization (NMF) based dimensionality reduction on a data * matrix (V). In the case of an image set with same dimensions, each row of * the matrix can be an image and the number of different images in an image * set would define the matrix. * * For a given choice of matrix V and the number of desired components in * the decomposition, two matrices W and H are generated such that V = W * H, * where elements of both W, H are positive. If V is a n * m matrix and * the user request for r decompositons, where r is genrally chosen such that * (n+m)r < nm, then the two resulting matrices W and H have dimensions * n * r and r * m, respectively. * * For more information about the algorithm, see D. Lee and H. Seung * ``Learning the parts of objects by non-negative matrix factorization'' * {\em Nature}, vol. 40, pp. 788-791, 1999. * * The user specifies entire data set of images as a matrix. In addition, * the user provies the number of desired decompositions. Furthermore, the * user has the option to choose a default initialization using a uniform * random number generator between 0 and 1, or provide their own * initialization matrix. * * Three virtual functions (Compute, NMFUpdate and TestConvergence) are * defined that needs to be implemented in the base classes. While the Compute * function is the main function that ensures the data in the matrix follow * all the restrictions imposed by the implemented NMF method, the NMFUpdate * provides with the specific update equations for a given choice of * metric in a decomposition. TestConvergence function provides a means * to detect if the decomposed matrices have stabilized between the * iterations. * * \ingroup ImageFeatureExtraction */ template< class T > class NMFBase: public ExceptionObject { public: typedef vnl_matrix InMatrixType; typedef vnl_matrix OutMatrixType; /** Set the input matrix V */ void SetInput(InMatrixType &inmat) { m_V = &inmat; } InMatrixType* GetInput() { return m_V; } /** Set/Get the number of classes to decompose; default is 2. */ void SetNumberOfClasses( unsigned int numClass ) { m_NumberOfClasses = numClass; } unsigned int GetNumberOfClasses() { return m_NumberOfClasses; } /** Initialization for the decompositon * The default uses a uniform random number between 0-1 */ virtual void Initialize(); /** User specified initialization */ virtual void Initialize(OutMatrixType *w, OutMatrixType *h); /** Set/Get the output dompostion for the W matrix */ void SetWMatrix( const OutMatrixType &w) { m_W = w; } OutMatrixType* GetWMatrix() { return &m_W; } /** Set/Get the output dompostion for the H matrix */ void SetHMatrix( const OutMatrixType &h) { m_H = h; } OutMatrixType* GetHMatrix() { return &m_H; } /** Main function to call the NMF factorization code */ virtual void Compute()=0; /** Main function implementing the NMF updates * for a given choice of metric */ virtual void NMFUpdate()=0; /** Write the NMF Matrix to a file as a binary stream */ int Write(OutMatrixType *m, std::string& str); /** Default implementation to test convergence for matrix H */ virtual unsigned int TestConvergence() = 0; /** Constructor */ NMFBase(); /** Desctructor */ virtual ~NMFBase() throw(); private: InMatrixType *m_V; OutMatrixType m_H; OutMatrixType m_W; unsigned int m_NumberOfClasses; }; // class itkNMFbase } // namespace itk #ifndef ITK_MANUAL_INSTANTIATION #include "itkNMFBase.txx" #endif #endif