Skip to content

คำอธิบายวิธีทำพร้อม Code สำหรับข้อ toi21_quartet


Problem

สรุปโจทย์

ในเมืองนครแห่งหนึ่ง มีสถานที่อยู่ \(3\) ที่ ได้แก่

  • หมู่บ้าน \(N\) แห่ง แต่ละแห่งถูกกำกับด้วยตัวเลขตั้งแต่ \(1\) ถึง \(N\)
  • ศูนย์ตำบล \(M\) แห่ง แต่ละศูนย์ถูกกำกับด้วยตัวเลขตั้งแต่ \(N+1\) ถึง \(N+M\)
  • ถนน \(N+M−1\) สาย โดยถนนจะไม่มี cycle และ ถนนทุกสายสามารถสัญจรสวนทางกันได้ โดยในเมืองแห่งนี้ เราจะเรียก “กลุ่มอพยพปลอดภัย” เป็นกลุ่มของหมู่บ้าน \(4\) แห่ง ที่จะต้องมีศูนย์ตำบลอย่างน้อย \(1\) ศูนย์ ที่ชาวบ้านจากทั้ง \(4\) หมู่บ้านสามารถเดินทางมายังศูนย์ตำบลศูนย์นั้นได้โดยไม่ต้องใช้ถนนสายเดียวกันเลย

สิ่งที่ต้องทำ

นับจำนวน “กลุ่มอพยพปลอดภัย” โดยให้ระบุคำตอบในรูปแบบของเศษจากการหารจำนวนของ “กลุ่มอพยพปลอดภัย” ด้วย \(10^9+7\)

Constraints

\(4 \leq N \leq 10^5\)
\(1 \leq M \leq N-2\)

Prerequisites

  • DFS/BFS
  • Dynamic Programming
  • DP on Tree

Solution

ข้อสังเกต

จำนวนถนน จะเท่ากับ \(N+M−1\) ซึ่งหมายความว่า Graph ที่ได้มาจะมีลักษณะเป็น Tree

ไอเดียหลัก

เราจะเก็บว่า สำหรับแต่ละศูนย์ตำบล \(u\) ถ้าเราเลือกให้หมู่บ้าน \(4\) แห่งเดินทางมายังศูนย์ตำบล \(u\) โดยไม่ใช้ถนนสายเดียวกันเลย เราจะนับได้กี่กลุ่ม (นับจำนวนกลุ่มอพยพปลอดภัยที่สามารถเดินทางมา \(u\) ได้)

วิธีทำ

เราจะทำ Dynamic Programming on Tree บน Tree ที่ได้รับมา โดยเราจะเก็บใน array of vector (ตั้งชื่อว่า \(cnt\)) ว่า ถ้าเราอยู่ที่โหนด \(u\) แล้ว หากเราเดินตามถนนแต่ละสายที่ติดกับโหนด \(u\) (เดินตาม adjacency list) แล้ว จะมีหมู่บ้านจำนวนกี่หมู่บ้านในเส้นทางนั้น (นับจำนวนหมู่บ้านที่อยู่ใน subtree ของ \(adj[u][i]\) แล้วเก็บไว้ใน \(cnt[u][i]\)) นั่นคือ ขนาดของ \(cnt[u]\) สำหรับแต่ละ \(u\) อาจจะมีขนาดไม่เท่ากันนั่นเองเมื่อเราได้ \(cnt\) มาแล้ว ทีนี้ เราจะทำการ loop จาก \(N+1\) ถึง \(N+M\) (loop แค่ศูนย์ตำบล) แล้วจะคำนวณว่า ถ้าหากว่าเราจะนับจากศูนย์ตำบลที่ \(u\) แล้ว จะมี “กลุ่มอพยพปลอดภัย” ทั้งหมดกี่กลุ่ม และเมื่อเราคำนวณครบสำหรับศูนย์ตำบลทั้ง \(M\) ศูนย์แล้ว เราจะได้คำตอบออกมา

วิธีการคำนวณ: - การคำนวณแบบ Brute Force: เราจะทำการ loop 4 ชั้น สำหรับทุกๆ \(u\) โดยจะ loop \(i,j,k,l\) และสำหรับทุกๆครั้งที่ loop เราจะเพิ่มค่า \(ans\) ไปเป็น \(cnt[u][i]∗cnt[u][j]∗cnt[u][k]∗cnt[u][l]\) แต่วิธีนี้จะช้าเกินไปเพราะในเวลาที่แย่ที่สุดจะต้องใช้ time complexity ในการคำนวณ \(O(n^4)\)

  • การคำนวณโดยใช้ Dynamic Programming: เราจะมี 2d array ตั้งชื่อว่า \(dp\) โดยการคำนวณจะเป็นดังนี้: State ของ Dynamic Programming จะเป็น \(dp[i][j]=\) จำนวนวิธีที่จะเลือกหมู่บ้านมา \(j\) หมู่บ้าน ถ้าหากว่าเราคิดแค่ \(i\) เมืองแรกที่เชื่อมติดกับโหนด \(u\) Recurrence Relation จะเป็น:
  • \(dp[i][0]=1\) (มีวิธีเลือก \(0\) เมืองทั้งหมด \(1\) วิธี ก็คือ ไม่เลือกอะไรเลย)
  • \(dp[i][j]=dp[i−1][j]+(dp[i−1][j−1]∗cnt[u][i])\) (จะมี 2 กรณีได้แก่: 1. ไม่เลือกจาก \(adj[u][i]\) และ 2. เลือกจาก \(adj[u][i]\))

แล้วตอนจบ เราก็แค่เพิ่ม \(ans\) ไปกับค่าของ \(dp[adj[u].size()][4]\)

Summary

  • ใช้ DP on Tree ในการคำนวณหาขนาดของแต่ละ subtree ของแต่ละ node
  • ใช้ DP ในการคำนวณวิธีการเลือกหมู่บ้าน \(4\) แห่งจากศูนย์ตำบลแต่ละศูนย์

Code:

toi21_quartet.cpp
#include <bits/stdc++.h> 

using namespace std; 

#define int long long 

const int mod = 1e9 + 7; 
const int N = 1e5 + 5; 
const int M = 1e5 + 5; 

int n, m, ans, sz[N + M]; 
vector <int> adj[N + M], cnt[N + M]; 

void dfs(int u, int p){ 
    sz[u] = (u <= n); 
    for (auto v : adj[u]) { 
        if (v == p) continue; 
        dfs(v, u); 
        if (sz[v] > 0) cnt[u].emplace_back(sz[v]), sz[u] += sz[v]; 
    } 
    cnt[u].emplace_back(n - sz[u]); 
} 

int32_t main(){ 
    cin.tie(NULL)->sync_with_stdio(false); 
    cin >> n >> m; 
    for (int i = 1; i < n + m; i++) { 
        int u, v; cin >> u >> v; 
        adj[u].emplace_back(v); 
        adj[v].emplace_back(u); 
    } 
    dfs(1, 1); 
    for (int u = n + 1; u <= n + m; u++) { 
        int siz = adj[u].size(); 
        if (siz >= 4) { 
            int dp[siz + 1][5]; 
            memset(dp, 0, sizeof dp); 
            dp[0][0] = 1; 
            for (int i = 1; i <= siz; i++) { 
                dp[i][0] = 1; 
                for (int j = 1; j <= 4; j++) { 
                    dp[i][j] = dp[i - 1][j] + (dp[i - 1][j - 1] * cnt[u][i - 1]); 
                } 
            } 
            ans += dp[siz][4]; ans %= mod; 
        } 
    } 
    cout << ans; 
}

Total Time Complexity

\(O(n+m)\)