// Listing 3: Multithreaded word count program

//
// partial listing of mtwc.h
//
// Tuning values for the word-count thread pool. For the first two, a
// value of 0 requests the default noted beside the constant.
const int NUM_WORKER_THREADS = 0;       // 0 => default, 2 * #CPUs + 2
const int THREAD_CONCURRENCY = 0;       // 0 => default, 1 per CPU
// Size in bytes of each context's read buffer, and the number of bytes
// requested by every ReadFile call.
const int READ_BUFFER_SIZE = 4 * 1024;

// Per-file bookkeeping for the word counter. main() allocates one
// context per file named on the command line and passes its address
// around as the completion key, so every completion packet for a file
// leads back to that file's context.
class CWordCountContext {
public:
    // Defined outside this listing.
    // NOTE(review): presumably zeroes the counters and the OVERLAPPED and
    // sets state to STATE_UNOPENED -- confirm against the implementation.
    CWordCountContext();

    // Progress of a single file through the worker.
    enum enState {
        // No attempt has been made to open the file yet.
        STATE_UNOPENED,
        // CreateFile succeeded; reads are being issued.
        STATE_OPENED,
        // Assigned by CWordCountWorker::closeFile().
        STATE_CLOSED,
        // Assigned when CreateFile returns INVALID_HANDLE_VALUE.
        STATE_ERROR_OPENING,
        // Assigned when ReadFile fails with an error other than
        // ERROR_IO_PENDING or ERROR_HANDLE_EOF.
        STATE_ERROR_READING
    };

    // File name; main() points this at the matching argv entry (not a copy).
    LPSTR       lpszFilename ;
    // Handle returned by CreateFile (INVALID_HANDLE_VALUE if the open failed).
    HANDLE      hFile ;
    // Passed to every ReadFile call; its Offset is advanced by the number
    // of bytes transferred after each completed read.
    OVERLAPPED  overlapped ;
    // Destination of each ReadFile call.
    char        buffer[ READ_BUFFER_SIZE ];
    // Current state, see enState.
    enState     state;
    // Word-boundary flag carried across buffers: set to 0 when a
    // whitespace character is seen and to 1 when a word starts; each
    // 0 -> 1 transition counts one word.
    int         spaceState ;
    // Running total of bytes transferred by completed reads.
    DWORD       dwNumChars ;
    // Running total of words (0 -> 1 transitions of spaceState).
    DWORD       dwNumWords ;
    // Running total of '\n' bytes seen.
    DWORD       dwNumLines ;
};


// Thread pool specialised for the word counter. On top of CThreadPool
// it keeps a count of files still to be finished and an event that is
// signaled when that count reaches zero.
class CWordCountThreadPool : public CThreadPool {
public:

    // numThreads and threadConcurrency are forwarded to CThreadPool.
    // numFiles is the initial value of the outstanding-file counter and
    // must be greater than zero (asserted in the constructor).
    CWordCountThreadPool( DWORD numThreads, DWORD threadConcurrency, 
        DWORD numFiles );
    // Defined outside this listing.
    // NOTE(review): presumably closes m_hDone -- confirm.
    ~CWordCountThreadPool();

    // Defined outside this listing.
    // NOTE(review): presumably returns a new CWordCountWorker for pPool
    // -- confirm.
    CWorkerThread* CreateWorkerThread( CThreadPool* pPool );
    // Defined outside this listing.
    // NOTE(review): presumably blocks until m_hDone is signaled -- confirm.
    void WaitUntilDone();
    // Called once per finished file (from CWordCountWorker::closeFile):
    // decrements m_NumFiles and signals m_hDone when it reaches zero.
    void ProcessedFile();

protected:
    // Files not yet finished; decremented with InterlockedDecrement.
    long    m_NumFiles ;
    // Manual-reset event, created unsignaled by the constructor.
    HANDLE    m_hDone ;
};


// Worker thread for the word counter. The CWorkerThread base is
// expected to dequeue completion packets and hand each one to
// OnReceivedCompletionPacket(), which opens, reads, counts and finally
// closes the file identified by the packet's completion key.
class CWordCountWorker : public CWorkerThread {
public:

    // pPool is the owning pool; it is passed on to CWorkerThread.
    CWordCountWorker( CWordCountThreadPool* pPool ) :
        CWorkerThread( pPool ) {}

    // Handles one completion packet.
    //   bResult                    - result flag delivered with the packet
    //   dwNumberOfBytesTransferred - bytes placed in the context's buffer
    //                                by the read that completed
    //   dwKey                      - the file's CWordCountContext*, cast
    //                                to DWORD
    //   lpOverlapped               - OVERLAPPED of the completed read
    void OnReceivedCompletionPacket( BOOL bResult,
        DWORD dwNumberOfBytesTransferred, DWORD dwKey,
        LPOVERLAPPED lpOverlapped );

protected:
    // Issues the next overlapped read into pContext->buffer. Parameter
    // named pContext (a single context) to match the definition.
    void readFile( CWordCountContext* pContext );
    // Closes the file, if open, and reports it to the pool as finished.
    void closeFile( CWordCountContext* pContext );
};

// -------------------------------------------------------------------
// Partial listing of mtwc.cpp

//
// Handles one completion packet for a word-count context.
//
// Parameters:
//   bResult                    - result flag delivered with the packet
//                                (not examined here)
//   dwNumberOfBytesTransferred - number of bytes the completed read
//                                placed in pContext->buffer
//   dwKey                      - the file's CWordCountContext*, cast to
//                                DWORD by whoever posted the packet or
//                                associated the file handle
//   lpOverlapped               - OVERLAPPED of the completed read; only
//                                used in STATE_OPENED, where its file
//                                offset is advanced before the next read
//
// What happens depends on the context's state:
//   STATE_UNOPENED - open the file, associate it with the pool's
//                    completion port and issue the first read
//   STATE_OPENED   - zero bytes means end of file, so close; otherwise
//                    count the buffer and issue the next read
//   anything else  - programming error (assert)
//
// NOTE(review): the key is carried as a DWORD, so the pointer round-trip
// assumes a build in which pointers fit in 32 bits -- the signature is
// fixed by CWorkerThread; confirm before targeting Win64.
//
void CWordCountWorker::OnReceivedCompletionPacket( 
    BOOL bResult, DWORD dwNumberOfBytesTransferred, DWORD dwKey,
    LPOVERLAPPED lpOverlapped )
{
    CWordCountContext* pContext = 
        reinterpret_cast<CWordCountContext*>(dwKey);

    if( pContext->state == CWordCountContext::STATE_UNOPENED ) {
        pContext->hFile = CreateFile( pContext->lpszFilename, 
            GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING,
            FILE_FLAG_OVERLAPPED | FILE_FLAG_SEQUENTIAL_SCAN, NULL );

        if( pContext->hFile == INVALID_HANDLE_VALUE ) {
            pContext->state = CWordCountContext::STATE_ERROR_OPENING;
            closeFile( pContext );
        } else {
            // Associate overlapped operations on this file 
            // with our I/O completion port
            m_pThreadPool->AssociateFile( pContext->hFile, 
                reinterpret_cast<DWORD>(pContext) );
            pContext->state = CWordCountContext::STATE_OPENED ;
            readFile( pContext ); // kick off first read...
        } // else
    } else if( pContext->state == CWordCountContext::STATE_OPENED ) {
        if( dwNumberOfBytesTransferred == 0 ) {
            // we are at end of file
            closeFile( pContext );
        } else {
            // This is the guts of the word-counting...
            for( DWORD i=0; i<dwNumberOfBytesTransferred; i++ ) {
                const char c = pContext->buffer[i];

                if( c == '\n' ) {
                    pContext->dwNumLines++ ;
                } // if

                // A newline is whitespace too, so it must end the
                // current word; otherwise words separated only by '\n'
                // would be counted as one. The cast keeps isspace()
                // defined for bytes >= 0x80 when char is signed.
                if( isspace( static_cast<unsigned char>( c ) ) ) {
                    pContext->spaceState = 0 ;
                } else if( pContext->spaceState == 0 ) {
                    // first non-space character after whitespace
                    pContext->spaceState = 1 ;
                    pContext->dwNumWords++ ;
                } // else
            } // for

            pContext->dwNumChars += dwNumberOfBytesTransferred ;

            // Advance the 64-bit file position held in Offset/OffsetHigh,
            // carrying into the high half when the low half wraps so
            // that files larger than 4 GB keep reading forward.
            const DWORD dwPreviousLow = lpOverlapped->Offset ;
            lpOverlapped->Offset += dwNumberOfBytesTransferred ;
            if( lpOverlapped->Offset < dwPreviousLow ) {
                lpOverlapped->OffsetHigh++ ;
            } // if

            readFile( pContext );
        } // else
    } else {
        // unknown state--should never happen
        assert(0);
    } // else
}



//
// Requests the next READ_BUFFER_SIZE bytes of the file into
// pContext->buffer, using the context's OVERLAPPED for the file position.
//
// Nothing further is done here when ReadFile reports success or
// ERROR_IO_PENDING. ERROR_HANDLE_EOF closes the file; any other error
// records STATE_ERROR_READING and then closes the file.
//
void CWordCountWorker::readFile( CWordCountContext* pContext )
{
    if( ReadFile( pContext->hFile, pContext->buffer, READ_BUFFER_SIZE,
                  NULL, &pContext->overlapped ) ) {
        return;
    }

    switch( GetLastError() ) {
    case ERROR_IO_PENDING:
        // the read was queued asynchronously -- the normal outcome
        break;

    case ERROR_HANDLE_EOF:
        closeFile( pContext );
        break;

    default:
        pContext->state = CWordCountContext::STATE_ERROR_READING;
        closeFile( pContext );
        break;
    }
}


//
// Finishes processing of one file: records the final state, closes the
// handle if one is open, and tells the pool that a file is done.
//
// Callers set STATE_ERROR_OPENING or STATE_ERROR_READING immediately
// before calling this function; such a state is left in place so the
// failure stays visible in the context. Every other state becomes
// STATE_CLOSED.
//
void CWordCountWorker::closeFile( CWordCountContext* pContext )
{
    if( pContext->state != CWordCountContext::STATE_ERROR_OPENING &&
        pContext->state != CWordCountContext::STATE_ERROR_READING ) {
        pContext->state = CWordCountContext::STATE_CLOSED ;
    } // if

    if( pContext->hFile != INVALID_HANDLE_VALUE ) {
        CloseHandle( pContext->hFile );
        // do not leave a stale handle value behind
        pContext->hFile = INVALID_HANDLE_VALUE ;
    } // if

    // ProcessedFile() may signal the pool's "done" event, so all writes
    // to the context must be complete before it is called.
    CWordCountThreadPool* pPool =
        dynamic_cast<CWordCountThreadPool*>( m_pThreadPool );
    assert( pPool != NULL );
    pPool->ProcessedFile();
}


//
// Builds the pool. numThreads and threadConcurrency go straight to
// CThreadPool; numFiles (which must be non-zero) seeds the count of
// files still to be processed. Also creates the unnamed "done" event
// as a manual-reset event that starts out unsignaled.
//
CWordCountThreadPool::CWordCountThreadPool( DWORD numThreads,
    DWORD threadConcurrency, DWORD numFiles )
    : CThreadPool( numThreads, threadConcurrency )
    , m_NumFiles( numFiles )
{
    assert( numFiles > 0 );

    const BOOL bManualReset       = TRUE ;
    const BOOL bInitiallySignaled = FALSE ;
    m_hDone = CreateEvent( NULL, bManualReset, bInitiallySignaled, NULL );
}


//
// Records that one more file has been completely processed and signals
// m_hDone once no files remain.
//
// Called concurrently from worker threads. The decision is taken on the
// value returned by InterlockedDecrement rather than on a second,
// unsynchronized read of m_NumFiles, so exactly the thread that brings
// the counter to zero signals the event.
//
void CWordCountThreadPool::ProcessedFile()
{
    if( InterlockedDecrement( &m_NumFiles ) == 0 ) {
        SetEvent( m_hDone );
    } // if
}


//
// Program entry point: counts lines, words and characters in every file
// named on the command line, using the thread pool, then reports the
// results.
//
// Returns 0 on success. Returns 1 when no file was named (after
// printing the usage text) or when an exception was caught, so that
// scripts can tell failure from success.
//
int main( int argc, char* argv[] )
{
    CWordCountContext* aCtx = NULL ;
    int exitCode = 0 ;

    try {
        if( argc < 2 ) {
            cerr << "Usage: mtwc <file1> [ <file2> ... <filen> ]" 
                 << endl
                 << "Counts lines, words, chars in files and reports" 
                    " results" << endl;
            return 1 ;
        } // if

        int numFiles = argc-1 ; // -1 since program name 
                                // is first argument
        aCtx = new CWordCountContext[ numFiles ];

        CWordCountThreadPool threadPool(NUM_WORKER_THREADS, 
            THREAD_CONCURRENCY, numFiles);
        threadPool.Start();

        for( int i=0; i<numFiles; i++ ) {
            // assign file name to context
            aCtx[i].lpszFilename = argv[i+1]; 
            // kick off a worker thread; the context address travels as
            // the completion key.
            // NOTE(review): the key is a DWORD, which assumes pointers
            // fit in 32 bits -- confirm before building for Win64.
            threadPool.PostQueuedCompletionStatus( 
                reinterpret_cast<DWORD>( &(aCtx[i]) ) );
        } // for

        // wait for threads to complete...
        threadPool.WaitUntilDone();
        threadPool.Stop();

        reportResults( aCtx, numFiles );

    } catch( ... ) {
        cerr << "Unhandled exception" << endl ;
        exitCode = 1 ;
    } // catch

    // delete [] on a null pointer is a no-op, so no check is needed
    delete [] aCtx ;

    return exitCode ;
}
/* End of File */