C++ 实现
#pragma once
#include <algorithm>
#include <functional>
#include <queue>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
/// Generic A* search over an implicit graph.
///
/// The graph is described entirely by two callbacks. `neighbor` lists the successors
/// of a node together with the cost of stepping to each one. `heuristic` estimates
/// the remaining cost from a node to the goal.
///
/// @tparam grid_t   Type of the map object passed to the constructor. Only a reference
///                  to it is kept, and find_path never reads it. It must outlive the AStar.
/// @tparam cost_t   Cost type. `cost_t()` is used as zero; it needs `+` and `<`.
/// @tparam pos_t    Node type, used as an unordered_map key (must be copyable).
/// @tparam hash_t   Default-constructible hash functor for pos_t.
/// @tparam equal_t  Default-constructible equality functor for pos_t.
template <typename grid_t, typename cost_t, typename pos_t, typename hash_t, typename equal_t>
class AStar {
public:
    /// (node, step cost) pair produced by the neighbor callback.
    using pos_cost_t = std::tuple<pos_t, cost_t>;
    /// h(node, goal). For find_path to return a cheapest path, h must never
    /// overestimate the true remaining cost.
    using heuristic_f = std::function<cost_t(const pos_t&, const pos_t&)>;
    /// Appends each successor of the node, with the cost of reaching it, to the vector.
    using neighbor_f = std::function<void(const pos_t&, std::vector<pos_cost_t>&)>;

    AStar(const grid_t& grid, heuristic_f heuristic, neighbor_f neighbor)
        : grid_(grid), heuristic_(std::move(heuristic)), neighbor_(std::move(neighbor)) {}

    /// Searches for a cheapest path from `from` to `to`.
    ///
    /// @param from  Start node.
    /// @param to    Goal node (matched with equal_t).
    /// @param cost  If non-null and a path is found, receives the total path cost.
    /// @param path  If non-null and a path is found, is overwritten with the nodes
    ///              from `from` to `to`, both inclusive.
    /// @return true if `to` was reached; false otherwise (outputs are left untouched).
    bool find_path(const pos_t& from, const pos_t& to, cost_t* cost, std::vector<pos_t>* path)
    {
        // Open-list entry: (node, f = g + h, g at the time it was pushed).
        using entry_t = std::tuple<pos_t, cost_t, cost_t>;
        // priority_queue is a max-heap. Ordering by "greater f" makes it pop the smallest f first.
        auto greater_f = [](const entry_t& a, const entry_t& b) {
            return std::get<1>(b) < std::get<1>(a);
        };
        std::priority_queue<entry_t, std::vector<entry_t>, decltype(greater_f)> open(greater_f);
        std::unordered_map<pos_t, cost_t, hash_t, equal_t> g;         // best known cost from `from`
        std::unordered_map<pos_t, pos_t, hash_t, equal_t> came_from;  // predecessor on that path
        std::vector<pos_cost_t> neighbors;                             // reused between expansions

        g.emplace(from, cost_t());
        open.emplace(from, heuristic_(from, to), cost_t());
        while (!open.empty()) {
            const entry_t top = open.top();  // copy: pop() destroys the queued element
            open.pop();
            const pos_t& cur = std::get<0>(top);
            const cost_t g_cur = std::get<2>(top);
            // A node is pushed again whenever its g improves. Older entries for it are
            // stale and would only repeat work, so skip them.
            if (g.at(cur) < g_cur) {
                continue;
            }
            if (equal_t()(cur, to)) {
                if (cost) {
                    *cost = g_cur;
                }
                if (path) {
                    // Walk predecessors back to `from`, then flip into start-to-goal order.
                    path->clear();
                    pos_t node = cur;
                    for (auto it = came_from.find(node); it != came_from.end(); it = came_from.find(node)) {
                        path->push_back(node);
                        node = it->second;
                    }
                    path->push_back(from);
                    std::reverse(path->begin(), path->end());
                }
                return true;
            }
            neighbors.clear();
            neighbor_(cur, neighbors);
            for (const auto& [next, step_cost] : neighbors) {
                const cost_t new_g = g_cur + step_cost;
                auto [it, inserted] = g.try_emplace(next, new_g);
                if (inserted || new_g < it->second) {
                    it->second = new_g;
                    came_from.insert_or_assign(next, cur);
                    open.emplace(next, new_g + heuristic_(next, to), new_g);
                }
            }
        }
        return false;
    }

private:
    const grid_t& grid_;  // not read by find_path; kept for API compatibility
    const heuristic_f heuristic_;
    const neighbor_f neighbor_;
};
测试代码
#include "astar.h"
#include <algorithm>
#include <iomanip>
#include <iostream>
using namespace std;
namespace std {
// Hash for std::pair that mixes the member hashes in the style of boost::hash_combine.
// Plain XOR would make (a, b) and (b, a) collide for same-typed members, and would send
// every (x, x) to 0.
// NOTE(review): specializing std::hash for a pair of standard types is not permitted by
// [namespace.std] (a program-defined type must be involved); prefer a named hasher outside std.
template <typename T1, typename T2>
struct hash<pair<T1, T2>> {
    size_t operator()(const pair<T1, T2>& x) const noexcept
    {
        size_t seed = hash<T1>()(x.first);
        // Golden-ratio constant plus shifted seed make the combination order-sensitive.
        seed ^= hash<T2>()(x.second) + static_cast<size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2);
        return seed;
    }
};
}
/// Prints a 2-D container row by row. Each value is left-aligned in a field of
/// width 2 and followed by a space. Rows end with a newline, and a blank line
/// closes the dump.
///
/// @param grid  Any range of ranges whose elements support operator<<.
/// @param os    Destination stream (defaults to std::cout). Its format flags are
///              restored on return, so the left-alignment does not leak to the caller.
template <typename T>
void dump2d(const T& grid, std::ostream& os = std::cout)
{
    const std::ios_base::fmtflags saved = os.flags();
    for (const auto& row : grid) {
        for (const auto& v : row) {
            os << std::setw(2) << std::left << v << " ";
        }
        os << '\n';
    }
    os << '\n';
    os.flags(saved);
}
/// Length, in cells, of the shortest 8-connected path of 0-cells from the top-left
/// corner to the bottom-right corner of `grid`, or -1 if there is none.
///
/// The grid is assumed to be square (n x n): both coordinates are bounds-checked
/// against grid.size(). On success the cost and the path (cells numbered 1..k
/// along the route) are printed to stdout. `grid` itself is not modified.
int shortestPathBinaryMatrix(vector<vector<int>>& grid)
{
    using pos = pair<int, int>;
    // Packs both coordinates into one 64-bit key so distinct cells never collide.
    struct pos_hash {
        size_t operator()(const pos& p) const noexcept
        {
            const auto hi = static_cast<unsigned long long>(static_cast<unsigned>(p.first));
            const auto lo = static_cast<unsigned long long>(static_cast<unsigned>(p.second));
            return hash<unsigned long long>()((hi << 32) | lo);
        }
    };

    const int n = int(grid.size());
    // Both corners must be open. This also rejects [[1]], where start == goal.
    if (n == 0 || grid[0][0] != 0 || grid[n - 1][n - 1] != 0) {
        return -1;
    }

    // Chebyshev distance equals the move count for unit-cost 8-way moves on an
    // obstacle-free grid, so it never overestimates.
    const auto heuristic = [](const pos& from, const pos& to) {
        const int dx = to.first - from.first;
        const int dy = to.second - from.second;
        return max({ dx, -dx, dy, -dy });
    };
    // Only open cells are ever pushed (start is checked above), so p itself is open.
    const auto neighbor = [&grid, n](const pos& p, vector<tuple<pos, int>>& v) {
        constexpr pair<int, int> dirs[] = {
            { -1, -1 }, { 1, 0 }, { 0, 1 }, { -1, 0 }, { 0, -1 }, { 1, 1 }, { 1, -1 }, { -1, 1 }
        };
        for (const auto& [dx, dy] : dirs) {
            const int x = p.first + dx;
            const int y = p.second + dy;
            if (x >= 0 && x < n && y >= 0 && y < n && !grid[x][y]) {
                v.emplace_back(pos{ x, y }, 1);
            }
        }
    };

    AStar<vector<vector<int>>, int, pos, pos_hash, equal_to<pos>> astar(grid, heuristic, neighbor);
    const pos from{ 0, 0 }, to{ n - 1, n - 1 };
    int moves = 0;
    vector<pos> path;
    if (!astar.find_path(from, to, &moves, &path)) {
        return -1;
    }
    const int cost = moves + 1;  // path length counts cells, i.e. moves + 1
    cout << "cost: " << cost << "\n";
    cout << "path: \n";
    // Draw on a copy so the caller's grid keeps its 0/1 contents.
    vector<vector<int>> shown = grid;
    for (int i = 0; i < int(path.size()); ++i) {
        shown[path[i].first][path[i].second] = i + 1;
    }
    dump2d(shown);
    return cost;
}
int main()
{
    // Sample 3x3 map (0 = open, 1 = blocked). The open cells form the top row and
    // the right column; with diagonal moves the shortest route is 4 cells long.
    vector<vector<int>> grid{
        { 0, 0, 0 },
        { 1, 1, 0 },
        { 1, 1, 0 },
    };
    shortestPathBinaryMatrix(grid);
}