题目分析:
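首先考虑答案有无数条（输出 $-1$）的情况。出现这种情况，当且仅当某条合法路径经过了某个 $0$ 环（边权和为 $0$ 的环）上的点：此时可以在该环上绕任意多圈而路径长度不变。因此，先在只由 $0$ 权边构成的子图上用 Tarjan 求强连通分量，找出所有 $0$ 环中的点。再用两次 Dijkstra 求出每个点 $u$ 与 $1$ 的距离 $d_1(u)$ 和与 $n$ 的距离 $d_n(u)$，后者在反图上求。若某个 $0$ 环中的点 $u$ 满足 $d_1(u)+d_n(u)$ 与最短路长度之差不超过 $k$，则输出 $-1$。否则这些点必定不会被任何合法路径经过，删去后图中不再有 $0$ 环，DP 的后效性随之消失。设 $f[u][j]$ 表示从 $1$ 到 $u$、长度恰为 $d_1(u)+j$ 的路线数（$0\le j\le k$），在反图上记忆化搜索转移即可。由于每条边对每个 $j$ 至多转移一次，DP 部分为 $O(mk)$；再加上 Dijkstra 的 $O((n+m)\log m)$，总时间复杂度为 $O((n+m)\log m+mk)$，可以通过所有数据。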
代码:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<stack>
#include<utility>
#include<vector>
using namespace std;
typedef pair<int,int> pr;

// Capacity limits.
// NOTE(review): sized for n <= 100000 nodes and k <= 50 (assumed problem
// limits) -- confirm against the problem statement.
const int maxn = 100010;  // node capacity; nodes are numbered 1..n
const int maxk = 52;      // second DP dimension; must be at least k + 1

// n nodes, m edges, k = allowed slack over the shortest 1 -> n length,
// p = modulus for the route count.
int n, m, k, p;

// g[u]: forward edges (to, weight); rg[v]: reverse edges (from, weight).
vector<pair<int,int> > g[maxn], rg[maxn];
// ng[u]: targets of zero-weight edges only, used to find zero-weight cycles.
vector<int> ng[maxn];
// Tarjan state: discovery index, low-link, instack[u] == 1 once u has been
// assigned to a finished SCC, scc[u] == 1 marks u as lying on a zero-weight
// cycle; cl is the DFS clock.
int dfn[maxn], low[maxn], instack[maxn], scc[maxn], cl;
// f[u][j]: number of routes 1 -> u of length dist(1,u) + j, modulo p.
// arr[u][j]: memo flag, set once f[u][j] is computed (or being computed).
int f[maxn][maxk], arr[maxn][maxk];
stack <int> sta; void init(){
memset(dfn,,sizeof(dfn));
memset(arr,,sizeof(arr));
memset(low,,sizeof(low));
memset(instack,,sizeof(instack));
memset(scc,,sizeof(scc));
memset(f,,sizeof(f));
cl = ;
} void tarjan(int now){
low[now] = dfn[now] = ++cl;
sta.push(now);
for(int i=;i<ng[now].size();i++){
int k = ng[now][i]; if(instack[k]) continue;
if(dfn[k]) { low[now] = min(low[now],dfn[k]); }
else{ tarjan(k); low[now] = min(low[now],low[k]); }
}
if(low[now] == dfn[now]){
stack<int> tpp;int num = ;
while(true){
int tp = sta.top();sta.pop();tpp.push(tp);num++;
instack[tp] = ; scc[tp]=;
if(now == tp) break;
}
if(num == ) scc[tpp.top()] = ;
}
} void read(){
scanf("%d%d%d%d",&n,&m,&k,&p);
for(int i=;i<=n;i++) g[i].clear();
for(int j=;j<=n;j++) ng[j].clear();
for(int j=;j<=n;j++) rg[j].clear();
for(int i=;i<=m;i++){
int u,v,w; scanf("%d%d%d",&u,&v,&w);
g[u].push_back(make_pair(v,w));
rg[v].push_back(make_pair(u,w));
if(w == ) ng[u].push_back(v);
}
} int res;
int d[maxn][];
priority_queue<pr,vector<pr>,greater<pr> > pq;
void solve_dist(){
memset(d,0x3f,sizeof(d));
pq.push(make_pair(,));
while(!pq.empty()){
pair<int,int> tp = pq.top();pq.pop();
if(d[tp.second][] > 1e9) d[tp.second][] = tp.first;
else continue;
for(int j=;j<g[tp.second].size();j++){
int nxt = g[tp.second][j].first,data = g[tp.second][j].second;
data += d[tp.second][];
if(data > d[nxt][]) continue;
pq.push(make_pair(data,nxt));
}
}
res = d[n][];
pq.push(make_pair(,n));
while(!pq.empty()){
pair<int,int> tp = pq.top();pq.pop();
if(d[tp.second][] > 1e9) d[tp.second][] = tp.first;
else continue;
for(int j=;j<rg[tp.second].size();j++){
int nxt = rg[tp.second][j].first,data = rg[tp.second][j].second;
data += d[tp.second][];
if(data > d[nxt][]) continue;
pq.push(make_pair(data,nxt));
}
}
} void dfs(int now,int lw){
if(arr[now][lw]) return;
int nowd = d[now][]+lw;arr[now][lw] = ;
for(int j=;j<rg[now].size();j++){
int nxt = rg[now][j].first,data = rg[now][j].second;
if(scc[nxt])continue;
int nd = nowd-data; nd -= d[nxt][];
if(nd < ) continue;
if(nd > k) continue;
dfs(nxt,nd);
f[now][lw] += f[nxt][nd]; f[now][lw] %= p;
}
} void work(){
for(int i=;i<=n;i++) if(!dfn[i]) tarjan(i);
if(scc[]||scc[n]) {puts("-1");return;} solve_dist(); for(int i=;i<=n;i++)
if(scc[i]) if(d[i][]+d[i][]<=res+k){puts("-1");return;} f[][] = ;arr[][] = ;int ans = ;
for(int i=;i<=k;i++) {
dfs(n,i);ans += f[n][i]; ans %= p;
}
printf("%d\n",ans);
} int main(){
int t; scanf("%d",&t);
while(t--){ init(); read(); work(); }
return ;
}