Sunday, November 8, 2009

On aliakseis' Opus #2

Somewhere within the discussion threads on the Yahoo! Groups for AZSPCS is a link to code that solves the D=3 problem by brute force. After careful inspection of the code, I can safely say that I now fully understand it.

The algorithm generates every admissible basis and keeps the one with the best score. At each level of the recursion, it tries, in increasing order, every number strictly larger than the previous number but less than or equal to the score of the numbers we have so far. The score is implicit in a global data structure that is maintained throughout the search. The common leave-it-as-you-found-it principle is enforced, so any changes made to the data structure need to be undone upon backing out of the recursion.
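To make the shape of the search concrete, here is a minimal sketch of the recursion in plain C++. It is not the code under discussion: the names ScoreOf, Search, FindBest, and SearchResult are my own, and the score is recomputed from scratch at every node instead of being maintained incrementally, which is far slower but easier to follow. Like the original, it assumes the basis always starts with 1.

```cpp
#include <cstddef>
#include <vector>

// Score of a basis: the lowest positive value that is NOT a sum of 1, 2, or 3
// basis numbers (repetition allowed). Recomputed from scratch for clarity.
// Assumes the basis is non-empty and strictly increasing.
int ScoreOf(const std::vector<int>& basis) {
    // The largest possible sum is 3 * max, so index 3 * max + 1 stays false
    // and the scan below always terminates.
    const int limit = 3 * basis.back() + 2;
    std::vector<bool> reachable(limit, false);
    for (int a : basis) {
        reachable[a] = true;
        for (int b : basis) {
            reachable[a + b] = true;
            for (int c : basis) reachable[a + b + c] = true;
        }
    }
    int s = 1;
    while (reachable[s]) ++s;
    return s;
}

struct SearchResult {
    int bestScore = 0;
    std::vector<int> bestBasis;
};

// Extends the basis until it has `size` numbers, trying every candidate in
// (previous number, current score] in increasing order.
void Search(std::vector<int>& basis, std::size_t size, SearchResult& out) {
    const int score = ScoreOf(basis);
    if (basis.size() == size) {
        if (score > out.bestScore) {
            out.bestScore = score;
            out.bestBasis = basis;
        }
        return;
    }
    const int prev = basis.back();
    for (int next = prev + 1; next <= score; ++next) {
        basis.push_back(next);
        Search(basis, size, out);
        basis.pop_back();  // leave the basis as we found it
    }
}

SearchResult FindBest(std::size_t size) {
    SearchResult out;
    std::vector<int> basis{1};  // assumption: 1 is always in the basis
    Search(basis, size, out);
    return out;
}
```

With this sketch, FindBest(3) settles on the basis 1 4 5, which covers every value from 1 to 15, so bestScore is 16.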

Conceptually, the data structure can be described as follows. Let count[x][d] be the number of ways the number x can be represented as a sum of exactly d numbers (d = 1, 2, or 3) in the basis we have so far. The current score, then, is the lowest positive value s such that count[s][1], count[s][2], and count[s][3] are all zero.

We start with no numbers in the basis, so count is initialized to all zeroes. Then, when a number b is introduced to the basis, we must update count as follows:
count[b][1]++
for j = b+1 to 3*b
   count[j][2] += count[j-b][1]
   count[j][3] += count[j-b][2]
Note that count must be updated left to right because of the directional data dependency: count[j-b] has to reflect the new number already, so that sums using b more than once are counted. Conversely, to undo this update, we subtract right to left and decrement count[b][1] after the loop.
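The update and its undo can be sketched in unpacked form as follows. This is only an illustration of the conceptual data structure, not the code under discussion; CountTable, MakeTable, AddToBasis, RemoveFromBasis, Score, and the table size kMaxSum are names and values I made up for the purpose.

```cpp
#include <array>
#include <vector>

// Hypothetical table size; b must satisfy 3 * b < kMaxSum.
constexpr int kMaxSum = 1024;

// count[x][d]: number of ways to write x as a sum of exactly d basis numbers.
// Index d = 0 is unused so that the indices match the prose.
using CountTable = std::vector<std::array<int, 4>>;

CountTable MakeTable() { return CountTable(kMaxSum, std::array<int, 4>{}); }

// Adds b to the basis. The loop runs left to right so that count[j - b]
// already includes the sums that use b.
void AddToBasis(CountTable& count, int b) {
    ++count[b][1];
    for (int j = b + 1; j <= 3 * b; ++j) {
        count[j][2] += count[j - b][1];
        count[j][3] += count[j - b][2];
    }
}

// Undoes AddToBasis(count, b): the mirror image, running right to left, so
// that count[j - b] still holds the value it had during the forward pass.
void RemoveFromBasis(CountTable& count, int b) {
    for (int j = 3 * b; j > b; --j) {
        count[j][3] -= count[j - b][2];
        count[j][2] -= count[j - b][1];
    }
    --count[b][1];
}

// Lowest positive value with no representation using 1, 2, or 3 numbers.
int Score(const CountTable& count) {
    int s = 1;
    while (s < kMaxSum && (count[s][1] || count[s][2] || count[s][3])) ++s;
    return s;
}
```

Starting from an empty table, adding 1 gives a score of 4, adding 3 on top of that gives a score of 8, and removing 3 again restores the table exactly.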

The code essentially implements the above algorithm, spiced up with various low-level optimizations. The most important one here is the packing of count[x][3], count[x][2], count[x][1] into one number. This is quite ingenious because the two O(D) operations (checking if all are zeroes, and the shifted addition/subtraction) are now O(1).
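Here is a small sketch of the packing idea in isolation. I am assuming 8 bits per count, which is what the shifts by 8 in the code suggest, and I use an explicit mask where the original casts to short; Pack, ShiftedAdd, ShiftedSub, and IsUnreachable are my own hypothetical helpers. The trick only works while every count stays below 256, otherwise one field carries into the next.

```cpp
#include <cstdint>

// Assumed layout: count[x][1] in bits 0-7, count[x][2] in bits 8-15,
// count[x][3] in bits 16-23.
constexpr std::uint32_t Pack(unsigned c1, unsigned c2, unsigned c3) {
    return c1 | (c2 << 8) | (c3 << 16);
}

// One step of the update, done for all d at once:
//   count[j][2] += count[j-b][1];  count[j][3] += count[j-b][2];
// Keeping the low 16 bits of `from` and shifting left by 8 moves the d = 1
// field onto d = 2 and the d = 2 field onto d = 3.
inline void ShiftedAdd(std::uint32_t& to, std::uint32_t from) {
    to += (from & 0xFFFFu) << 8;
}

// The matching undo step.
inline void ShiftedSub(std::uint32_t& to, std::uint32_t from) {
    to -= (from & 0xFFFFu) << 8;
}

// "All three counts are zero" becomes a single comparison.
inline bool IsUnreachable(std::uint32_t packed) { return packed == 0; }
```

For example, applying ShiftedAdd to a word holding counts (0, 1, 0) with a source word holding (2, 3, 9) yields (0, 3, 3): the source's third count is dropped, as it should be, since a sum of three numbers cannot be extended.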

Another obvious optimization is the use of what I later found out is called Duff's device. I'm familiar with the concept of loop unrolling, but was previously unaware of this construct (or even that a switch statement can be (ab)used in this manner!).
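For anyone else who has not seen it before, here is a generic illustration of Duff's device, unrelated to the search itself. AddArrays is a made-up function that adds one array to another with the loop unrolled eight times; the switch jumps into the middle of the do-while body to handle the leftover n % 8 elements on the first pass.

```cpp
#include <cstddef>

// Adds src[i] to dst[i] for i in [0, n), unrolled 8x with Duff's device.
void AddArrays(int* dst, const int* src, std::size_t n) {
    if (n == 0) return;  // the do-while body would otherwise run once anyway
    std::size_t passes = (n + 7) / 8;  // number of trips through the loop
    switch (n % 8) {
    case 0: do { *dst++ += *src++;
    case 7:      *dst++ += *src++;
    case 6:      *dst++ += *src++;
    case 5:      *dst++ += *src++;
    case 4:      *dst++ += *src++;
    case 3:      *dst++ += *src++;
    case 2:      *dst++ += *src++;
    case 1:      *dst++ += *src++;
            } while (--passes > 0);
    }
}
```

With n = 10, the switch enters at case 2, so the first pass handles 2 elements and the second pass handles the remaining 8.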

And that pretty much describes the code in a nutshell.
/* From http://aliakseis.livejournal.com/2791.html */

#include <iostream>

using std::cout; 

long g_buf[1024]; 
int g_arrValues[25]; 
int g_arrBest[25]; 
int g_nLevels; 
int g_nBest; 

void Step(int nLevel) 
{
    int nVal = g_arrValues[nLevel++]; 

    if(nLevel == g_nLevels) 
    {
        while(g_buf[nVal++] != 0); 

        if(nVal > g_nBest) 
        {
            for(int i = 0; i < nLevel; i++) 
                g_arrBest[i] = g_arrValues[i];
            g_nBest = nVal;
        }
    }
    else 
    {
        do 
        {
            ++g_buf[nVal++];

            long* pFrom = g_buf; 
            long* pTo = g_buf + nVal;
            int nCount = (nVal + 7) / 8; 

            switch (nVal & 7) 
            {
            case 0: 
                do 
                {
                    *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 7: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 6: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 5: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 4: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 3: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 2: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
            case 1: *pTo++ += short(*pFrom++) << 8; *pTo++ += short(*pFrom++) << 8; 
                }
                while(--nCount > 0); 
            }

            g_arrValues[nLevel] = nVal;

            Step(nLevel);

            pFrom = g_buf + nVal * 2 - 1;
            pTo = g_buf + nVal * 3 - 1;
            nCount = (nVal + 7) / 8;

            switch (nVal & 7) 
            {
            case 0: 
                do 
                {
                    *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 7: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 6: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 5: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 4: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 3: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 2: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
            case 1: *pTo-- -= short(*pFrom--) << 8; *pTo-- -= short(*pFrom--) << 8; 
                }
                while(--nCount > 0); 
            }
        }
        while(--g_buf[nVal-1] != 0); 
    }
}

int main(int argc, char* argv[]) 
{
    g_nBest = 0;
    g_arrValues[0] = 1;
    g_buf[0] = 0x00000001;
    g_buf[1] = 0x00000100;
    g_buf[2] = 0x00010000;
    g_nLevels = 9;

    Step(0);

    cout << "Max sum: " << g_nBest - 1 <<'\n';

    for (int i = 0; i < g_nLevels; ++i) 
        cout << g_arrBest[i] << ' ';

    cout << '\n';

    return 0; 
}

And the cycle is complete.
