「搜索」八数码八境界

  A* IDA* DBFS

Posted by StandHR on July 27, 2018
View times

PS:本文只实现了境界二、三、四、七、八。

一直听说八数码八境界,自己来实现下。配套PDF讲解。

Hint:下方代码求出的均为最短路径。

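下表为已实现各境界(二、三、四、七、八)的评测结果,每个境界先列 POJ1077、再列 HDU1043;T5000MS 表示在 HDU1043 上超时(TLE):

| 境界 | 题目 | Time | Memory |
| --- | --- | --- | --- |
| 二 | POJ1077 | 204MS | 4996K |
| 二 | HDU1043 | T5000MS | 6180K |
| 三 | POJ1077 | 282MS | 4548K |
| 三 | HDU1043 | 62MS | 5512K |
| 四 | POJ1077 | 16MS | 5160K |
| 四 | HDU1043 | 577MS | 6192K |
| 七 | POJ1077 | 16MS | 5768K |
| 七 | HDU1043 | 546MS | 7796K |
| 八 | POJ1077 | 16MS | 628K |
| 八 | HDU1043 | 171MS | 1384K |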

境界一、广搜+STL

境界二、广搜+康托展开+哈希判重+逆序数判无解

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

// fac[k] == k!, the place values of the factorial number system used by getKT().
int fac[] = {1, 1, 2, 6, 24, 120, 720, 5040, 40320};

// One 3x3 board; the blank ('x' in the input) is stored as 0.
struct node {
    int M[9];    // tiles in row-major order, 0 = blank
    int whereX;  // index of the blank in M
    int kT;      // cached Cantor rank (1-based) as returned by getKT()

    // Counts pairs (earlier cell, later non-blank cell) whose tiles are out of
    // order. main() reports "unsolvable" when this count is odd.
    int getInverse() {
        int inversions = 0;
        for (int a = 0; a < 9; ++a)
            for (int b = a + 1; b < 9; ++b)
                if (M[b] != 0 && M[a] > M[b])
                    ++inversions;
        return inversions;
    }

    // Cantor expansion: the lexicographic rank of M among all permutations
    // of 0..8, shifted to start at 1 (so it lies in 1..362880).
    int getKT() {
        int rank = 0;
        for (int pos = 0; pos < 9; ++pos) {
            int smallerAfter = 0;
            for (int k = pos + 1; k < 9; ++k)
                smallerAfter += (M[k] < M[pos]) ? 1 : 0;
            rank += smallerAfter * fac[8 - pos];
        }
        return rank + 1;
    }
};

// BFS tree links indexed by Cantor rank: `pre` is the parent's rank (-1 marks
// the search root) and `move` is the letter of the move that reached the state.
struct Road {
    int pre;
    char move;
};
Road road[400000];

bool vis[400000];            // vis[rank]: state already discovered in this search
node endM;                   // goal board, filled in by main()
int dir[] = {1, 3, -1, -3};  // blank offset for move index 0..3
string dirPath("rdlu");      // move letter for the same index

// Prints the moves leading from the search root to state `pos`, oldest first,
// by following `pre` links back to the root (pre == -1) and emitting them in
// reverse.
void print(int pos) {
    string moves;
    for (int cur = pos; road[cur].pre != -1; cur = road[cur].pre)
        moves += road[cur].move;
    for (int k = (int)moves.size() - 1; k >= 0; --k)
        putchar(moves[k]);
}

// Plain BFS from `start` over Cantor-ranked states. As soon as the goal is
// first generated, its move string is printed via print(). Nothing is printed
// when start already is the goal.
void bfs(node start) {
    if (start.kT == endM.kT)
        return;

    memset(vis, 0, sizeof(vis));
    start.kT = start.getKT();
    vis[start.kT] = true;

    queue<node> frontier;
    frontier.push(start);
    while (!frontier.empty()) {
        const node cur = frontier.front();
        frontier.pop();
        const int row = cur.whereX / 3;
        const int col = cur.whereX % 3;
        for (int d = 0; d < 4; ++d) {
            // Reject moves that would push the blank off the board.
            const bool offBoard = (d == 0 && col == 2) || (d == 1 && row == 2)
                               || (d == 2 && col == 0) || (d == 3 && row == 0);
            if (offBoard)
                continue;

            node nxt = cur;
            const int to = cur.whereX + dir[d];
            swap(nxt.M[cur.whereX], nxt.M[to]);
            nxt.whereX = to;
            nxt.kT = nxt.getKT();
            if (vis[nxt.kT])
                continue;

            vis[nxt.kT] = true;
            road[nxt.kT].move = dirPath[d];
            road[nxt.kT].pre = cur.kT;
            frontier.push(nxt);
            if (nxt.kT == endM.kT) {
                print(nxt.kT);
                return;
            }
        }
    }
}

// Reads one puzzle per line (tiles 1-8 and 'x' for the blank, separated by
// whitespace). For each line, prints a shortest move string or "unsolvable".
int main()
{
    for (int i = 0; i < 8; i++)
        endM.M[i] = i + 1;
    endM.M[8] = 0;
    endM.kT = endM.getKT();
    char a[105];
    // gets() is unbounded and was removed in C++14. fgets() is the bounded
    // replacement; it keeps the trailing '\n', which the parser below skips.
    while (fgets(a, sizeof(a), stdin)) {
        node start;
        int j = 0;
        int seen = 0;    // bit v is set once value v (0 = blank) has been read
        bool ok = true;
        for (int i = 0; a[i] != '\0'; i++) {
            int v;
            if (a[i] == 'x')
                v = 0;
            else if (a[i] >= '1' && a[i] <= '8')
                v = a[i] - '0';
            else
                continue;    // spaces, '\r', '\n' and other separators
            // More than nine tiles or a repeated tile: not a valid board
            // (and writing on would overflow start.M).
            if (j == 9 || ((seen >> v) & 1)) {
                ok = false;
                break;
            }
            seen |= 1 << v;
            start.M[j] = v;
            if (v == 0)
                start.whereX = j;
            j++;
        }
        // Skip blank or malformed lines instead of searching from an
        // uninitialised board.
        if (!ok || j != 9)
            continue;

        start.kT = start.getKT();
        road[start.kT].pre = -1;
        if (start.getInverse() & 1) {
            printf("unsolvable\n");
            continue;
        }
        bfs(start);
        putchar('\n');
    }
    return 0;
}

境界三、逆向广搜+康托+哈希打表

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

// fac[k] == k!, the place values of the factorial number system used by getKT().
int fac[] = {1, 1, 2, 6, 24, 120, 720, 5040, 40320};

// One 3x3 board; the blank is stored as 0.
struct node {
    int M[9];    // tiles in row-major order, 0 = blank
    int whereX;  // index of the blank in M
    int kT;      // cached Cantor rank (1-based) as returned by getKT()

    // Cantor expansion: lexicographic rank of M among the permutations of
    // 0..8, plus one, so a rank of 0 never occurs (1..362880).
    int getKT() {
        int rank = 0;
        for (int pos = 0; pos < 9; ++pos) {
            int smallerAfter = 0;
            for (int k = pos + 1; k < 9; ++k)
                if (M[k] < M[pos])
                    ++smallerAfter;
            rank += smallerAfter * fac[8 - pos];
        }
        return rank + 1;
    }
};

// Reverse-BFS table indexed by Cantor rank. `pre` is the rank of the BFS
// parent (one move closer to the goal); it is -1 for the goal itself and 0
// while unvisited, since ranks start at 1. `path` is the move letter leading
// from this state to `pre`.
struct vis {
    int pre;
    char path;
};
struct vis vis[400000];

node endM;                   // goal board, filled in by main()
int dir[] = {1, 3, -1, -3};  // blank offset for move index 0..3
string dirPath("lurd");      // reversed letters, since the search starts at the goal

// BFS from `start` (the goal) over every reachable state; run once by main().
// Afterwards vis[rank].pre != 0 exactly for the states reachable from the goal.
void bfs(node start) {
    vis[start.kT].pre = -1;
    queue<node> pending;
    pending.push(start);
    while (!pending.empty()) {
        const node cur = pending.front();
        pending.pop();
        const int row = cur.whereX / 3;
        const int col = cur.whereX % 3;
        for (int d = 0; d < 4; ++d) {
            // Reject moves that would push the blank off the board.
            const bool offBoard = (d == 0 && col == 2) || (d == 1 && row == 2)
                               || (d == 2 && col == 0) || (d == 3 && row == 0);
            if (offBoard)
                continue;

            node nxt = cur;
            const int to = cur.whereX + dir[d];
            swap(nxt.M[cur.whereX], nxt.M[to]);
            nxt.whereX = to;
            nxt.kT = nxt.getKT();
            if (vis[nxt.kT].pre != 0)
                continue;    // already discovered (the goal itself has pre == -1)

            vis[nxt.kT].pre = cur.kT;
            vis[nxt.kT].path = dirPath[d];
            pending.push(nxt);
        }
    }
}

// Prints the moves from state `pre` to the goal, in order, by following the
// reverse-BFS parent links until the goal (pre == -1) is reached.
void print(int pre) {
    for (int cur = pre; vis[cur].pre != -1; cur = vis[cur].pre)
        putchar(vis[cur].path);
}

// Builds the reverse-BFS table once, then answers one puzzle per input line
// (tiles 1-8 and 'x', whitespace separated) with a shortest move string or
// "unsolvable".
int main()
{
    for (int i = 0; i < 8; i++)
        endM.M[i] = i + 1;
    endM.M[8] = 0;
    endM.whereX = 8;
    endM.kT = endM.getKT();
    bfs(endM);
    char a[105];
    // gets() is unbounded and was removed in C++14. fgets() is the bounded
    // replacement; it keeps the trailing '\n', which the parser below skips.
    while (fgets(a, sizeof(a), stdin)) {
        node start;
        int j = 0;
        int seen = 0;    // bit v is set once value v (0 = blank) has been read
        bool ok = true;
        for (int i = 0; a[i] != '\0'; i++) {
            int v;
            if (a[i] == 'x')
                v = 0;
            else if (a[i] >= '1' && a[i] <= '8')
                v = a[i] - '0';
            else
                continue;    // spaces, '\r', '\n' and other separators
            // More than nine tiles or a repeated tile: not a valid board
            // (and writing on would overflow start.M).
            if (j == 9 || ((seen >> v) & 1)) {
                ok = false;
                break;
            }
            seen |= 1 << v;
            start.M[j] = v;
            if (v == 0)
                start.whereX = j;
            j++;
        }
        // Skip blank or malformed lines instead of looking up an
        // uninitialised board.
        if (!ok || j != 9)
            continue;

        start.kT = start.getKT();
        if (start.kT == endM.kT) {
            putchar('\n');
            continue;
        }
        // pre == 0 means the reverse BFS never reached this state.
        if (vis[start.kT].pre != 0) {
            print(start.kT);
        } else {
            cout << "unsolvable";
        }
        putchar('\n');
    }
    return 0;
}

境界四、双向广搜+康托+哈希+逆序数

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

// fac[k] == k!, the place values of the factorial number system used by getKT().
int fac[] = {1, 1, 2, 6, 24, 120, 720, 5040, 40320};

// One 3x3 board; the blank is stored as 0.
struct node {
    int M[9];    // tiles in row-major order, 0 = blank
    int whereX;  // index of the blank in M
    int kT;      // cached Cantor rank (1-based) as returned by getKT()

    // Counts pairs (earlier cell, later non-blank cell) whose tiles are out of
    // order. main() reports "unsolvable" when this count is odd.
    int getInverse() {
        int inversions = 0;
        for (int a = 0; a < 9; ++a)
            for (int b = a + 1; b < 9; ++b)
                if (M[b] != 0 && M[a] > M[b])
                    ++inversions;
        return inversions;
    }

    // Cantor expansion: lexicographic rank of M among the permutations of
    // 0..8, shifted to start at 1 (1..362880).
    int getKT() {
        int rank = 0;
        for (int pos = 0; pos < 9; ++pos) {
            int smallerAfter = 0;
            for (int k = pos + 1; k < 9; ++k)
                smallerAfter += (M[k] < M[pos]) ? 1 : 0;
            rank += smallerAfter * fac[8 - pos];
        }
        return rank + 1;
    }
};

// Search-tree links for both BFS directions, indexed by Cantor rank: `pre` is
// the parent's rank (-1 marks either root) and `move` the stored move letter.
struct Road {
    int pre;
    char move;
};
Road road[400000];

int vis[400000];             // 0 = unseen, 1 = reached from start, 2 = reached from goal
node endM;                   // goal board, filled in by main()
int dir[] = {1, 3, -1, -3};  // blank offset for move index 0..3
// Move letters per search side: [0] forward from the start, [1] reversed
// letters for the side that searches backwards from the goal.
string dirPath[2] = {"rdlu", "lurd"};

// Prints the start-side moves from the start state to `pos`, oldest first,
// collecting them along the `pre` links and emitting them in reverse.
void print0(int pos) {
    string moves;
    for (int cur = pos; road[cur].pre != -1; cur = road[cur].pre)
        moves += road[cur].move;
    for (int k = (int)moves.size() - 1; k >= 0; --k)
        putchar(moves[k]);
}

// Prints the goal-side moves from `pos` to the goal, in walking order: each
// stored letter already describes the step toward the parent.
void print1(int pos) {
    for (int cur = pos; road[cur].pre != -1; cur = road[cur].pre)
        putchar(road[cur].move);
}

// Bidirectional BFS between `start` (side 0) and the goal `endM` (side 1).
// vis[rank] is 0 (unseen), 1 (reached from start) or 2 (reached from goal).
// Both sides share road[]; the `move` letter comes from dirPath[0] on the start
// side and from the reversed alphabet dirPath[1] on the goal side.
// Each round expands one whole BFS layer of the smaller queue. The first time
// a generated state already belongs to the other side, the two half-paths are
// joined and printed, and the search ends. Prints nothing when start is the goal.
void dbfs(node start) {
    if (start.kT == endM.kT)
        return;

    memset(vis, 0, sizeof(vis));
    queue<node> Q[2];
    vis[start.kT] = 1;
    vis[endM.kT] = 2;

    Q[0].push(start);
    Q[1].push(endM);
    while (Q[0].size() && Q[1].size()) {
        // j selects the side with the smaller queue (ties go to the start side).
        int j = 0;
        if (Q[1].size() < Q[0].size())
            j++;
        // Expand exactly the nodes queued right now, i.e. one full BFS layer.
        int flag = (int)Q[j].size();
        while (flag --) {
            node now = Q[j].front();
            Q[j].pop();

            for (int i = 0; i < 4; i++) {
                node next = now;
                // Skip moves that would push the blank off the 3x3 board.
                if ((i == 0 && next.whereX % 3 == 2) || (i == 1 && next.whereX / 3 == 2)
                 || (i == 2 && next.whereX % 3 == 0) || (i == 3 && next.whereX / 3 == 0))
                    continue;

                swap(next.M[now.whereX], next.M[now.whereX + dir[i]]);
                next.whereX = now.whereX + dir[i];
                next.kT = next.getKT();

                if (vis[next.kT] == j + 1)
                    continue;  // already reached from this same side
                else if (vis[next.kT] == 0) {
                    road[next.kT].move = dirPath[j][i];
                    road[next.kT].pre = now.kT;
                    vis[next.kT] = j + 1;
                    Q[j].push(next);
                } else {
                    // The frontiers met: next was already reached by the other side.
                    if (j == 0) {
                        print0(now.kT);          // start -> now
                        putchar(dirPath[j][i]);  // now -> next
                        print1(next.kT);         // next -> goal
                    } else {
                        print0(next.kT);         // start -> next
                        putchar(dirPath[j][i]);  // next -> now (reversed letter)
                        print1(now.kT);          // now -> goal
                    }
                    return;
                }
            }
        }
    }
}

// Reads one puzzle per line (tiles 1-8 and 'x', whitespace separated). For each
// line, prints a shortest move string found by bidirectional BFS, or "unsolvable".
int main()
{
    for (int i = 0; i < 8; i++)
        endM.M[i] = i + 1;
    endM.M[8] = 0;
    endM.whereX = 8;
    endM.kT = endM.getKT();
    road[endM.kT].pre = -1;  // root of the goal-side search tree
    char a[105];
    // gets() is unbounded and was removed in C++14. fgets() is the bounded
    // replacement; it keeps the trailing '\n', which the parser below skips.
    while (fgets(a, sizeof(a), stdin)) {
        node start;
        int j = 0;
        int seen = 0;    // bit v is set once value v (0 = blank) has been read
        bool ok = true;
        for (int i = 0; a[i] != '\0'; i++) {
            int v;
            if (a[i] == 'x')
                v = 0;
            else if (a[i] >= '1' && a[i] <= '8')
                v = a[i] - '0';
            else
                continue;    // spaces, '\r', '\n' and other separators
            // More than nine tiles or a repeated tile: not a valid board
            // (and writing on would overflow start.M).
            if (j == 9 || ((seen >> v) & 1)) {
                ok = false;
                break;
            }
            seen |= 1 << v;
            start.M[j] = v;
            if (v == 0)
                start.whereX = j;
            j++;
        }
        // Skip blank or malformed lines instead of searching from an
        // uninitialised board.
        if (!ok || j != 9)
            continue;

        start.kT = start.getKT();

        if (start.getInverse() & 1) {
            printf("unsolvable\n");
            continue;
        }

        road[start.kT].pre = -1;  // root of the start-side search tree
        dbfs(start);

        putchar('\n');
    }
    return 0;
}

境界五、A*+康托+哈希+简单估价函数+逆序数

境界六、A*+康托+哈希+曼哈顿距离+逆序数

境界七、A*+康托+哈希+曼哈顿距离+小顶堆+逆序数

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <queue>
using namespace std;

typedef std::pair<int, int> PII;

// fac[k] == k!, the place values of the factorial number system used by getKT().
int fac[] = {1, 1, 2, 6, 24, 120, 720, 5040, 40320};

// Per-state record indexed by Cantor rank: parent rank (-1 for the root),
// the move letter that reached this state, and the best known path cost G.
struct Road {
    int pre;
    char move;
    int G;
};
Road road[400000];

// A board plus its A* costs: g = moves made so far, h = heuristic estimate.
struct node {
    int M[9];    // tiles in row-major order, 0 = blank
    int h;
    int g;
    int whereX;  // index of the blank in M
    int kT;      // Cantor rank (1-based) from getKT()

    // "Less" means lower priority, so std::priority_queue pops the smallest
    // f = g + h first and, among equal f, the smaller h.
    bool operator < (const node &r) const {
        const int mine = g + h;
        const int theirs = r.g + r.h;
        if (mine != theirs)
            return mine > theirs;
        return h > r.h;
    }

    // Sum of Manhattan distances of tiles 1..8 from their goal cells
    // (tile t belongs at row (t-1)/3, column (t-1)%3); the blank is ignored.
    int getH() {
        int dist = 0;
        for (int pos = 0; pos < 9; ++pos) {
            const int tile = M[pos];
            if (tile == 0)
                continue;
            dist += abs((tile - 1) / 3 - pos / 3) + abs((tile - 1) % 3 - pos % 3);
        }
        return dist;
    }

    // Counts pairs (earlier cell, later non-blank cell) whose tiles are out of
    // order. main() reports "unsolvable" when this count is odd.
    int getInverse() {
        int inversions = 0;
        for (int a = 0; a < 9; ++a)
            for (int b = a + 1; b < 9; ++b)
                if (M[b] != 0 && M[a] > M[b])
                    ++inversions;
        return inversions;
    }

    // Cantor expansion: lexicographic rank of M among the permutations of
    // 0..8, shifted to start at 1 (1..362880).
    int getKT() {
        int rank = 0;
        for (int pos = 0; pos < 9; ++pos) {
            int smallerAfter = 0;
            for (int k = pos + 1; k < 9; ++k)
                if (M[k] < M[pos])
                    ++smallerAfter;
            rank += smallerAfter * fac[8 - pos];
        }
        return rank + 1;
    }
};


// A* state per Cantor rank: 0 = unseen, 2 = queued, 1 = expanded.
int vis[400000];
node endM;                   // goal board, filled in by main()
int dir[] = {1, 3, -1, -3};  // blank offset for move index 0..3
string dirPath("rdlu");      // move letter for the same index

// Prints the moves from the search root to state `pos`, oldest first, by
// collecting them along the `pre` links and emitting them in reverse.
void print(int pos) {
    string moves;
    for (int cur = pos; road[cur].pre != -1; cur = road[cur].pre)
        moves += road[cur].move;
    for (int k = (int)moves.size() - 1; k >= 0; --k)
        putchar(moves[k]);
}

// A* search from `start` to the goal, using the Manhattan heuristic
// (node::getH) and a min-heap on f = g + h (node::operator<). Prints the move
// string when the goal is generated; prints nothing if start is the goal.
//
// vis[rank]:  0 = unseen, 2 = queued (open), 1 = expanded (closed).
// road[rank]: best known g, parent rank and move letter for that state.
void Astar(node start) {
    if (start.kT == endM.kT)
        return;

    memset(vis, 0, sizeof(vis));
    priority_queue<node> Q;
    start.g = 0;
    // The heap orders by g + h, so every pushed node needs a valid h
    // (previously the start node's h was left uninitialised).
    start.h = start.getH();
    road[start.kT].G = 0;
    road[start.kT].pre = -1;

    Q.push(start);
    while (!Q.empty()) {
        node now = Q.top();
        Q.pop();
        // A state whose g improved while queued is in the heap twice. The
        // cheaper copy has the smaller f and is popped first, so any later
        // copy is stale and must not be expanded again.
        if (vis[now.kT] == 1)
            continue;
        vis[now.kT] = 1;

        for (int i = 0; i < 4; i++) {
            // Skip moves that would push the blank off the 3x3 board.
            if ((i == 0 && now.whereX % 3 == 2) || (i == 1 && now.whereX / 3 == 2)
             || (i == 2 && now.whereX % 3 == 0) || (i == 3 && now.whereX / 3 == 0))
                continue;

            node next = now;
            swap(next.M[now.whereX], next.M[now.whereX + dir[i]]);
            next.whereX = now.whereX + dir[i];
            next.kT = next.getKT();
            next.g = now.g + 1;

            if (vis[next.kT] == 1)
                continue;  // already expanded
            if (vis[next.kT] == 2 && next.g >= road[next.kT].G)
                continue;  // already queued via an equally short or shorter path

            // New state, or a strictly shorter path to a queued one. `next`
            // was copied from `now`, so h must be recomputed before pushing.
            next.h = next.getH();
            road[next.kT].G = next.g;
            road[next.kT].pre = now.kT;
            road[next.kT].move = dirPath[i];
            vis[next.kT] = 2;

            // Stopping when the goal is generated (not popped) stays optimal.
            // Manhattan distance is consistent, so `now` was expanded with its
            // optimal g, and f(now) never exceeds the optimum. A board one move
            // from the goal has h == 1, hence next.g == f(now).
            if (next.kT == endM.kT) {
                print(next.kT);
                return;
            }
            Q.push(next);
        }
    }
}

// Reads one puzzle per line (tiles 1-8 and 'x', whitespace separated). For each
// line, prints a shortest move string found by A*, or "unsolvable".
int main()
{
    for (int i = 0; i < 8; i++)
        endM.M[i] = i + 1;
    endM.M[8] = 0;
    endM.whereX = 8;
    endM.kT = endM.getKT();
    char a[105];
    // gets() is unbounded and was removed in C++14. fgets() is the bounded
    // replacement; it keeps the trailing '\n', which the parser below skips.
    while (fgets(a, sizeof(a), stdin)) {
        node start;
        int j = 0;
        int seen = 0;    // bit v is set once value v (0 = blank) has been read
        bool ok = true;
        for (int i = 0; a[i] != '\0'; i++) {
            int v;
            if (a[i] == 'x')
                v = 0;
            else if (a[i] >= '1' && a[i] <= '8')
                v = a[i] - '0';
            else
                continue;    // spaces, '\r', '\n' and other separators
            // More than nine tiles or a repeated tile: not a valid board
            // (and writing on would overflow start.M).
            if (j == 9 || ((seen >> v) & 1)) {
                ok = false;
                break;
            }
            seen |= 1 << v;
            start.M[j] = v;
            if (v == 0)
                start.whereX = j;
            j++;
        }
        // Skip blank or malformed lines instead of searching from an
        // uninitialised board.
        if (!ok || j != 9)
            continue;

        start.kT = start.getKT();
        if (start.getInverse() & 1) {
            printf("unsolvable\n");
            continue;
        }
        Astar(start);
        putchar('\n');
    }
    return 0;
}

境界八、IDA*+曼哈顿距离+逆序数

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <utility>
using namespace std;

typedef pair<int, int> PII;
// Initial value of the "smallest over-bound f seen" in dfs().
#define INF 1e6

// Move letters of the current IDA* path; Path[k] is the move made at depth k.
char Path[100005];

// One 3x3 board for IDA*: tiles row-major in M (0 = blank), the depth g at
// which it was reached, and its heuristic h.
struct node {
    int M[9];
    int g;
    int h;
    int whereX;  // index of the blank in M

    // Boards are equal when every cell matches; g, h and whereX are ignored.
    bool operator == (const node &r) const {
        return memcmp(M, r.M, sizeof(M)) == 0;
    }

    // Sum of Manhattan distances of tiles 1..8 from their goal cells
    // (tile t belongs at row (t-1)/3, column (t-1)%3); the blank is ignored.
    int get_H() {
        int dist = 0;
        for (int pos = 0; pos < 9; ++pos) {
            const int tile = M[pos];
            if (tile == 0)
                continue;
            dist += abs((tile - 1) / 3 - pos / 3) + abs((tile - 1) % 3 - pos % 3);
        }
        return dist;
    }

    // Counts pairs (earlier cell, later non-blank cell) whose tiles are out of
    // order. main() reports "unsolvable" when this count is odd.
    int getInverse() {
        int inversions = 0;
        for (int a = 0; a < 9; ++a)
            for (int b = a + 1; b < 9; ++b)
                if (M[b] != 0 && M[a] > M[b])
                    ++inversions;
        return inversions;
    }
};

node endM;                   // goal board, filled in by main()
int dir[] = {1, 3, -1, -3};  // blank offset for move index 0..3
string dirPath("rdlu");      // move letter for the same index; (i + 2) % 4 is the reverse

// One bounded depth-first pass of IDA*. Returns 0 once the goal is generated;
// Path then holds the NUL-terminated move string. Otherwise returns the
// smallest f = g + h that exceeded `maxDepth`, which becomes the next bound.
int dfs(node now, int maxDepth) {
    const int f = now.g + now.h;
    if (f > maxDepth)
        return f;

    int best = INF;
    const int row = now.whereX / 3;
    const int col = now.whereX % 3;
    for (int i = 0; i < 4; i++) {
        // Never undo the previous move: move (i + 2) % 4 is the reverse of i.
        if (now.g > 0 && Path[now.g - 1] == dirPath[(i + 2) % 4])
            continue;
        // Reject moves that would push the blank off the board.
        const bool offBoard = (i == 0 && col == 2) || (i == 1 && row == 2)
                           || (i == 2 && col == 0) || (i == 3 && row == 0);
        if (offBoard)
            continue;

        node child = now;
        const int to = now.whereX + dir[i];
        swap(child.M[now.whereX], child.M[to]);
        child.whereX = to;
        child.h = child.get_H();
        child.g = now.g + 1;
        Path[now.g] = dirPath[i];

        if (child == endM) {
            Path[child.g] = '\0';
            return 0;
        }

        const int t = dfs(child, maxDepth);
        if (t == 0)
            return 0;
        if (t < best)
            best = t;
    }
    return best;
}

// Iterative-deepening A*. The f bound starts at h(start) and is raised to the
// smallest over-bound f reported by dfs() until a pass reaches the goal; the
// move string in Path is then printed. Prints nothing if start is the goal.
void IDA(node start) {
    if (start == endM)
        return;
    start.g = 0;
    start.h = start.get_H();
    for (int bound = start.h; ; ) {
        const int next = dfs(start, bound);
        if (next == 0) {
            cout << Path;
            return;
        }
        bound = next;
    }
}

// Reads one puzzle per line (tiles 1-8 and 'x', whitespace separated). For each
// line, prints a shortest move string found by IDA*, or "unsolvable".
int main()
{
    for (int i = 0; i < 8; i++)
        endM.M[i] = i + 1;
    endM.M[8] = 0;
    endM.whereX = 8;
    char a[105];
    // gets() is unbounded and was removed in C++14. fgets() is the bounded
    // replacement; it keeps the trailing '\n', which the parser below skips.
    while (fgets(a, sizeof(a), stdin)) {
        node start;
        int j = 0;
        int seen = 0;    // bit v is set once value v (0 = blank) has been read
        bool ok = true;
        for (int i = 0; a[i] != '\0'; i++) {
            int v;
            if (a[i] == 'x')
                v = 0;
            else if (a[i] >= '1' && a[i] <= '8')
                v = a[i] - '0';
            else
                continue;    // spaces, '\r', '\n' and other separators
            // More than nine tiles or a repeated tile: not a valid board
            // (and writing on would overflow start.M).
            if (j == 9 || ((seen >> v) & 1)) {
                ok = false;
                break;
            }
            seen |= 1 << v;
            start.M[j] = v;
            if (v == 0)
                start.whereX = j;
            j++;
        }
        // Skip blank or malformed lines instead of searching from an
        // uninitialised board.
        if (!ok || j != 9)
            continue;

        if (start.getInverse() & 1) {
            printf("unsolvable\n");
            continue;
        }
        IDA(start);
        putchar('\n');
    }
    return 0;
}