算法(二)数据结构


算法基础课

单链表

模板:

// head存储链表头,e[]存储节点的值,ne[]存储节点的next指针,idx表示当前用到了哪个节点
int head, e[N], ne[N], idx;

// 初始化
void init()
{
    head = -1;
    idx = 0;
}

// 在链表头插入一个数a
void insert(int a)
{
    e[idx] = a, ne[idx] = head, head = idx ++ ;
}

// 将头结点删除,需要保证头结点存在
void remove()
{
    head = ne[head];
}
public static int head, idx;
public static int e[] = new int[N];
public static int ne[] = new int[N];

// 初始化数据
public static void init() {
    head = -1;
    idx = 0;
}

// 将val插入到头结点
public static void addToHead(int val) {
    e[idx] = val;
    ne[idx] = head;
    head = idx++;
}

// 将下标是k的点后面的点删掉
public static void remove(int k) {
    ne[k] = ne[ne[k]];
}

public static void add(int k, int val) {
    e[idx] = val;
    ne[idx] = ne[k];
    ne[k] = idx++;
}

原题链接:单链表

#include <iostream>

using namespace std;

const int N = 100010;

//head 表示头结点的下标
//e[i] 表示节点i的值
//ne[i] 表示节点i的next指针是多少
//idx 存储当前已经用到了哪些点
int head,e[N],ne[N],idx;

//初始化
void init()
{
    head = -1;
    idx = 0;
}

//将x插到头结点
void add_to_head(int x)
{
    e[idx] = x; //存储新的节点值
    ne[idx] = head; //将新的节点插入到首节点前面
    head = idx++; //将head指向新节点
}

//将x插入到下标是k的节点的后面
void add(int k,int x)
{
	e[idx] = x;
    ne[idx] = ne[k];
    ne[k] = idx++;
}

//将下标是k的点的后面的点删除掉
void remove(int k)
{
    ne[k] = ne[ne[k]];
}
    
int main()
{
    init();
    int m ;
    cin >> m;
    while(m--)
    {
        int k , x ; 
        char op;
        cin >> op;
        if(op == 'H')
        {
            cin >> x;
            add_to_head(x);
        }
        else if(op == 'D')
        {
            cin >> k;
            if(!k) head = ne[head];
            remove(k-1);
        }
        else
        {
            cin >> k >> x ;
            add(k-1,x);
        }
    }
    
    for(int i = head ; i != -1 ; i = ne[i]) cout << e[i] << ' ' ;
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static StreamTokenizer in = new StreamTokenizer(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));

	public static int nextInt() throws Exception {
		in.nextToken();
		return (int) in.nval;
	}

	public static int N = 100010, M;
	public static int head, idx;
	public static int e[] = new int[N];
	public static int ne[] = new int[N];

	// 初始化数据
	public static void init() {
		head = -1;
		idx = 0;
	}

	// 将val插入到头结点
	public static void addToHead(int val) {
		e[idx] = val;
		ne[idx] = head;
		head = idx++;
	}

	// 将下标是k的点后面的点删掉
	public static void remove(int k) {
		ne[k] = ne[ne[k]];
	}

	public static void add(int k, int val) {
		e[idx] = val;
		ne[idx] = ne[k];
		ne[k] = idx++;
	}

	public static void main(String[] agrs) throws Exception {
		init();
		M = Integer.parseInt(br.readLine());
		while (M-- > 0) {
			String[] s = br.readLine().split(" ");
			if (s[0].equals("H")) {
				int val = Integer.parseInt(s[1]);
				addToHead(val);
			} else if (s[0].equals("I")) {
				int k = Integer.parseInt(s[1]);
				int val = Integer.parseInt(s[2]);
				add(k - 1, val); // 第 k个结点的下标为 k-1, 所以插入到下标为 k-1结点的后面
			} else {
				int k = Integer.parseInt(s[1]);

				if (k == 0) {
					head = ne[head];
				} else
					remove(k - 1);
			}
		}
        for (int i = head; i != -1; i = ne[i]) {
            out.print(e[i] + " ");
        }
        out.flush();
	}
}

双链表

模板:

// e[]表示节点的值,l[]表示节点的左指针,r[]表示节点的右指针,idx表示当前用到了哪个节点
int e[N], l[N], r[N], idx;

// 初始化
void init()
{
    //0是左端点,1是右端点
    r[0] = 1, l[1] = 0;
    idx = 2;
}

// 在节点a的右边插入一个数x
void insert(int a, int x)
{
    e[idx] = x;
    l[idx] = a, r[idx] = r[a];
    l[r[a]] = idx, r[a] = idx ++ ;
}

// 删除节点a
void remove(int a)
{
    l[r[a]] = l[a];
    r[l[a]] = r[a];
}
public static int idx;
public static int e[] = new int[N];
public static int l[] = new int[N];
public static int r[] = new int[N];
public static void init(){
    //0是左端点,1是右端点
    r[0] = 1;
    l[1] = 0;
    idx = 2;
}

public static void add(int a,int x){
    e[idx] = x;
    l[idx] = a;
    r[idx] = r[a];
    l[r[a]] = idx;
    r[a] = idx;
    idx ++ ;
}

public static void remove(int a){
    l[r[a]] = l[a];
    r[l[a]] = r[a];
}

原题链接:双链表

#include<iostream>

using namespace std;

const int N = 100010;

int r[N], l[N], idx , e[N];
void init()
{
    //0是左端点,1是右端点
    r[0] = 1, l[1] = 0;
    idx = 2;
}

//在下标是k的右边插入一个数
void add(int a,int x)
{
    e[idx] = x;
    l[idx] = a, r[idx] = r[a];
    l[r[a]] = idx , r[a] = idx;
    idx ++;
}

void remove(int a)
{
    l[r[a]] = l[a];
    r[l[a]] = r[a];
}

int main()
{
    int m ;
    cin >> m;
    init();
    while(m--)
    {
        string op ;
        int k,x;
        cin >> op;
        if(op == "L")
        {
            cin >> x;
            add(0,x);
        }
        else if(op == "R")
        {
            cin >> x;
            add(l[1],x);
        }
        else if(op == "D")
        {
            cin >> k;
            remove(k+1);
        }
        else if(op == "IL")
        {
            cin >> k >> x;
            add(l[k+1], x);
        }
        else 
        {
            cin >> k >> x;
            add(k+1 , x);
        }
    }
    
    for(int i = r[0]; i != 1 ; i = r[i]) cout << e[i] << ' ';
    return 0;
}
import java.io.*;
import java.util.*;

public class Main{
    public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
    public static int N = 100010;
    public static int head, idx;
    public static int e[] = new int[N];
    public static int l[] = new int[N];
    public static int r[] = new int[N];
    public static void init(){
        //0是左端点,1是右端点
        r[0] = 1;
        l[1] = 0;
        idx = 2;
    }
    
    public static void add(int a,int x){
        e[idx] = x;
        l[idx] = a;
        r[idx] = r[a];
        l[r[a]] = idx;
        r[a] = idx;
        idx ++ ;
    }
    
    public static void remove(int a){
        l[r[a]] = l[a];
        r[l[a]] = r[a];
    }
    
    public static void main(String[] args)throws Exception{
        init();
        int M =Integer.parseInt(br.readLine());
        while(M -- > 0){
            int k,x;
            String[] s = br.readLine().split(" ");
            if(s[0].equals("L")){
                x = Integer.parseInt(s[1]);
                add(0,x);
            }else if(s[0].equals("R")){
                x = Integer.parseInt(s[1]);
                add(l[1],x);
            }else if(s[0].equals("D")){
                k = Integer.parseInt(s[1]);
                remove(k+1);
            }else if(s[0].equals("IL")){
                k = Integer.parseInt(s[1]);
                x = Integer.parseInt(s[2]);
                add(l[k+1], x);
            }else {
                k = Integer.parseInt(s[1]);
                x = Integer.parseInt(s[2]);
                add(k+1, x);
            }
        }
        
        for(int i = r[0]; i != 1; i = r[i]) out.print(e[i] + " ");
        out.flush();
    }
}

模板:

// tt表示栈顶
int stk[N], tt = 0;

// 向栈顶插入一个数
stk[ ++ tt] = x;

// 从栈顶弹出一个数
tt -- ;

// 栈顶的值
stk[tt];

// 判断栈是否为空
if (tt > 0)
{

}

原题链接:https://www.acwing.com/problem/content/830/

#include<iostream>

using namespace std;

const int N = 100010;
int stk[N],tt ;

void push(int x)
{
    stk[++tt] = x;
}

void pop()
{
    tt --;
}

bool isempty()
{
    return tt == 0 ;
}

int query()
{
    return stk[tt];
}

int main()
{
    int m ;
    cin >> m;
    while(m--)
    {
        string op;
        cin >> op;
        if(op == "push"){
            int x;
            cin >> x;
            push(x);
        }else if( op == "pop" )
        {
            pop();
        }else if( op == "query")
        {
            cout << query() << endl;
        }else if( op == "empty")
        {
            bool flag = isempty();
            if(flag) cout << "YES" << endl;
            else cout << "NO" << endl;
        }
    }
    return 0;
}

队列

普通队列

// hh 表示队头,tt表示队尾
int q[N], hh = 0, tt = -1;

// 向队尾插入一个数
q[ ++ tt] = x;

// 从队头弹出一个数
hh ++ ;

// 队头的值
q[hh];

// 判断队列是否为空
if (hh <= tt)
{

}

原题链接:https://www.acwing.com/problem/content/831/

#include<iostream>

using namespace std;

const int N = 100010;

int q[N], hh = 0 , tt = -1 ;

// hh ... tt 
void push(int x)
{
    q[++tt] = x;
}

void pop()
{
    hh ++ ;
}

int query()
{
    return q[hh];
}

bool isempty()
{
    return hh > tt ;
}

int main()
{
    int m;
    cin >> m;
    while(m--)
    {
        string op;
        cin >> op;
        int x;
        if( op == "push" ) {
            cin >> x;
            push(x);
        }else if( op == "pop")
        {
            pop();
        }else if( op == "query")
        {
            cout << query() << endl;
        }else if( op == "empty")
        {
            bool flag = isempty();
            if(flag) cout << "YES" << endl;
            else cout << "NO" <<endl;
        }        
    }
    return 0;
}

循环队列

// hh 表示队头,tt表示队尾的后一个位置
int q[N], hh = 0, tt = 0;

// 向队尾插入一个数
q[tt ++ ] = x;
if (tt == N) tt = 0;

// 从队头弹出一个数
hh ++ ;
if (hh == N) hh = 0;

// 队头的值
q[hh];

// 判断队列是否为空
if (hh != tt)
{

}
// 在队尾插入元素,在队头弹出元素
int q[N], hh, tt = -1;

// 插入
q[++tt] = x;

// 弹出
hh++;

// 判断队列是否为空
if(hh <= tt) not empty;
else empty;

// 取出队头元素
q[hh];

单调栈

模板:

常见模型:找出每个数左边离它最近的比它大/小的数
int tt = 0;
for (int i = 1; i <= n; i ++ )
{
    while (tt && check(stk[tt], i)) tt -- ;
    stk[ ++ tt] = i;
}

原题链接:单调栈

#include<iostream>

using namespace std;
const int N = 100010;
int stk[N] , tt;

int main()
{
    int n;
    cin >> n;
    for(int i = 0 ; i < n ; i ++ )
    {  
        int x;
        cin >> x;
        while(tt && stk[tt] >= x) tt --;
        if(tt) cout << stk[tt] << ' ';
        else cout << -1 << ' ';
        stk[++tt] = x;
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static StreamTokenizer in = new StreamTokenizer(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));

	public static int nextInt() throws Exception {
		in.nextToken();
		return (int) in.nval;
	}

	public static int N = 100010;
	public static int n;
	public static Stack<Integer> st = new Stack<Integer>();

	public static void main(String[] args) throws Exception {
		n = nextInt();
		while (n-- > 0) {
			int x = nextInt();
			while (!st.empty() && st.peek() >= x) {
				st.pop();
			}
			if (st.empty())
				out.print("-1 ");
			else
				out.print(st.peek() + " ");
			st.push(x);
		}
		out.flush();
	}
}

单调队列

模板:

常见模型:找出滑动窗口中的最大值/最小值
int hh = 0, tt = -1;
for (int i = 0; i < n; i ++ )
{
    while (hh <= tt && check_out(q[hh])) hh ++ ;  // 判断队头是否滑出窗口
    while (hh <= tt && check(q[tt], i)) tt -- ;
    q[ ++ tt] = i;
}

原题链接:滑动窗口

#include <iostream>

using namespace std;

const int N = 1000010;

int n,k;
int a[N],q[N]; //q队列存的是下标

int main()
{
    scanf("%d",&n);
    scanf("%d",&k);
    for(int i = 0 ;i < n; i++) scanf("%d",&a[i]);
    
    int hh = 0, tt = -1;
    for ( int i = 0; i < n ; i++)
    {
        if(hh <= tt && i - k + 1 > q[hh]) hh++; //判断队头是否已经滑出窗口
        while(hh <= tt && a[q[tt]] >= a[i]) tt--;    //如果插入的数比队列中的数小,则将该数插入到队列中,需始终保持队列中的数严格递增,且大于插入数的数可以抛弃,因为不再需要,因此只需移动tt即可
        q[ ++ tt] = i; //插入队列
        if(i >= k - 1) printf("%d ", a[q[hh]]);    //输入最小的数,即队头的数
    }
    puts("");
    
    hh = 0, tt = -1;
    for ( int i = 0; i < n ; i++)
    {
        //判断队头是否已经滑出窗口
        if(hh <= tt && i - k + 1 > q[hh]) hh++;
        while(hh <= tt && a[q[tt]] <= a[i]) tt--;
        q[ ++ tt] = i;
        if(i >= k - 1) printf("%d ", a[q[hh]]);   
    }
    puts("");
    
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static StreamTokenizer in = new StreamTokenizer(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));

	public static int nextInt() throws Exception {
		in.nextToken();
		return (int) in.nval;
	}

	public static int N = 1000010;
	public static int n, k;
	public static int a[] = new int[N];
	public static Deque<Integer> q = new LinkedList<Integer>();

	public static void main(String[] args) throws Exception {
		n = nextInt();
		k = nextInt();
		for (int i = 0; i < n; i++)
			a[i] = nextInt();

		// 维护一个单调递增的队列,每次弹出队头
		for (int i = 0; i < n; i++) {
			if (!q.isEmpty() && i - k + 1 > q.getFirst())
				q.removeFirst();
			while (!q.isEmpty() && a[q.getLast()] >= a[i])
				q.removeLast();
			q.addLast(i);
			if (i >= k - 1)
				out.print(a[q.getFirst()] + " ");
		}
		out.println();
		q.clear();
		for (int i = 0; i < n; i++) {
			if (!q.isEmpty() && i - k + 1 > q.getFirst())
				q.removeFirst();
			while (!q.isEmpty() && a[q.getLast()] <= a[i])
				q.removeLast();
			q.addLast(i);
			if (i >= k - 1)
				out.print(a[q.getFirst()] + " ");
		}
		out.flush();
	}
}

KMP

模板:

// s[]是长文本,p[]是模式串,n是s的长度,m是p的长度
求模式串的Next数组:
for (int i = 2, j = 0; i <= m; i ++ )
{
    while (j && p[i] != p[j + 1]) j = ne[j];
    if (p[i] == p[j + 1]) j ++ ;
    ne[i] = j;
}

// 匹配
for (int i = 1, j = 0; i <= n; i ++ )
{
    while (j && s[i] != p[j + 1]) j = ne[j];
    if (s[i] == p[j + 1]) j ++ ;
    if (j == m)
    {
        j = ne[j];
        // 匹配成功后的逻辑
    }
}

原题链接:https://www.acwing.com/problem/content/833/

暴力算法

s[N],p[M];
for(int i = 1;i <= n;i ++)
{
    bool flag = true;
    for(int j = 1; j<= m; j++)
        if(s[i+j-1] != p[j])
        {
            flag = false;
            break;
        }
}

KMP算法

#include <iostream>

using namespace std;

const int N = 1e5+10 , M = 1e6+10;

int n, m;
int ne[N];
char s[M], p[N];

int main()
{
    cin >> n >> p + 1 >> m >> s + 1;

    //求next数组
    for (int i = 2, j = 0; i <= n; i ++ )
    {
        while (j && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j ++ ;
        ne[i] = j;
    }
    //kmp匹配操作
    for (int i = 1, j = 0; i <= m; i ++ )
    {
        while (j && s[i] != p[j + 1]) j = ne[j];//若不能再匹配,则j退一步到与该后缀相同的前缀字符串处(这里的后缀和前缀指的都是匹配字符串)
        if (s[i] == p[j + 1]) j ++ ;
        if (j == n)
        {
            printf("%d ", i - n);
            j = ne[j];
        }
    }

    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
	public static int N = 100010, M = 1000010;
	public static int n, m;
	public static char s[] = new char[M]; // 模式串
	public static char p[] = new char[N]; // 总串
	public static int ne[] = new int[N];

	public static void main(String[] args) throws Exception {
		n = Integer.parseInt(br.readLine());
		String P = br.readLine();
		for (int i = 1; i <= n; i++) {
			p[i] = P.charAt(i - 1);
		}
		m = Integer.parseInt(br.readLine());
		String S = br.readLine();
		for (int i = 1; i <= m; i++) {
			s[i] = S.charAt(i - 1);
		}

		// 构造前缀数组
		for (int i = 2, j = 0; i <= n; i++) {
			while (j != 0 && p[i] != p[j + 1])
				j = ne[j];
			if (p[i] == p[j + 1])
				j++;
			ne[i] = j;
		}

		// kmp匹配操作
		for (int i = 1, j = 0; i <= m; i++) {
			while (j != 0 && s[i] != p[j + 1])
				j = ne[j];
			if (s[i] == p[j + 1])
				j++;
			if (j == n) {
				out.print(i - n + " ");
				j = ne[j];
			}
		}
		out.flush();
	}
}

Trie树

快速存储字符串集合的数据结构

模板

int son[N][26], cnt[N], idx;
// 0号点既是根节点,又是空节点
// son[][]存储树中每个节点的子节点
// cnt[]存储以每个节点结尾的单词数量

// 插入一个字符串
void insert(char *str)
{
    int p = 0;
    for (int i = 0; str[i]; i ++ )
    {
        int u = str[i] - 'a';
        if (!son[p][u]) son[p][u] = ++ idx;
        p = son[p][u];
    }
    cnt[p] ++ ;
}

// 查询字符串出现的次数
int query(char *str)
{
    int p = 0;
    for (int i = 0; str[i]; i ++ )
    {
        int u = str[i] - 'a';
        if (!son[p][u]) return 0;
        p = son[p][u];
    }
    return cnt[p];
}
public static int son[][] = new int[N][26];
public static int cnt[] = new int[N];
public static int idx;

public static void insert(char[] s) {
    int p = 0;
    for (int i = 0; i < s.length; i++) {
        int u = s[i] - 'a';
        if (son[p][u] == 0)
            son[p][u] = ++idx;
        p = son[p][u];
    }
    cnt[p]++;
}

public static int query(char[] s) {
    int p = 0;
    for (int i = 0; i < s.length; i++) {
        int u = s[i] - 'a';
        if (son[p][u] == 0)
            return 0;
        p = son[p][u];
    }
    return cnt[p];
}

原题链接:Trie字符串统计

#include<iostream>

using namespace std;

const int N = 100010;

int son[N][26],cnt[N],idx;
//son[][] => tries树每个点的所有儿子
//cnt[] => 以当前这个点的结尾的单词有多少个
//idx => 当前用到了哪个下标 下标为0的点既是根节点,又是空节点
char str[N];

void insert(char str[]) 
{
    int p = 0; //根节点开始
    for(int i=0;str[i];i++)
    {
        int u = str[i] - 'a'; //取插入字符串的每个字符的编号 a-z => 0-25
        if(!son[p][u]) son[p][u] = ++idx; //创建该节点
        p = son[p][u]; //走原有或刚创建的子节点
    }
    cnt[p] ++ //记录个数加1
}

int query(char str[])
{
    int p = 0;
    for(int i = 0;str[i];i++)
    {
        int u = str[i] - 'a';
        if(!son[p][u]) return 0;
        p = son[p][u];
    }
    return cnt[p];
}

int main()
{
    int n;
    scanf("%d" , &n);
    while(n--)
    {
        char op[2];
        scanf("%s%s",op,str);
        if(op[0] == 'I') insert(str);
        else printf("%d\n",query(str));
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
	public static int N = 100010;
	public static int n;
	public static int son[][] = new int[N][26];
	public static int cnt[] = new int[N];
	public static int idx;

	public static void insert(char[] s) {
		int p = 0;
		for (int i = 0; i < s.length; i++) {
			int u = s[i] - 'a';
			if (son[p][u] == 0)
				son[p][u] = ++idx;
			p = son[p][u];
		}
		cnt[p]++;
	}

	public static int query(char[] s) {
		int p = 0;
		for (int i = 0; i < s.length; i++) {
			int u = s[i] - 'a';
			if (son[p][u] == 0)
				return 0;
			p = son[p][u];
		}
		return cnt[p];
	}

	public static void main(String[] args) throws Exception {
		n = Integer.parseInt(br.readLine());
		while (n-- > 0) {
			String[] strings = br.readLine().split(" ");
			char str[] = new char[N];
			if (strings[0].equals("I")) {
			    str = strings[1].toCharArray();
				insert(str);
			} else {
			    str = strings[1].toCharArray();
				int res = query(str);
				out.println(res);
			}
		}
		out.flush();
	}
}

并查集

  1. 将两个集合合并
  2. 询问两个元素是否在一个集合当中

基本原理:每个集合用一颗树表示。树根的编号就是整个集合的编号。每个节点存储它的父节点,p[x]表示x的父节点

  • 判断树根?if(p[x]=x)
  • 求x的集合编号?while(p[x]!=x) x=p[x];
  • 合并两个集合:px是x的集合编号,py是y的集合编号 p[x]=y

优化:路径压缩 遍历到根节点后,直接将所有的路径上的点与根节点相连

模板:

(1)朴素并查集:

    int p[N]; //存储每个点的祖宗节点

    // 返回x的祖宗节点
    int find(int x)
    {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }

    // 初始化,假定节点编号是1~n
    for (int i = 1; i <= n; i ++ ) p[i] = i;

    // 合并a和b所在的两个集合:
    p[find(a)] = find(b);


(2)维护size的并查集:

    int p[N], size[N];
    //p[]存储每个点的祖宗节点, size[]只有祖宗节点的有意义,表示祖宗节点所在集合中的点的数量

    // 返回x的祖宗节点
    int find(int x)
    {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }

    // 初始化,假定节点编号是1~n
    for (int i = 1; i <= n; i ++ )
    {
        p[i] = i;
        size[i] = 1;
    }

    // 合并a和b所在的两个集合:
    size[find(b)] += size[find(a)];
    p[find(a)] = find(b);


(3)维护到祖宗节点距离的并查集:

    int p[N], d[N];
    //p[]存储每个点的祖宗节点, d[x]存储x到p[x]的距离

    // 返回x的祖宗节点
    int find(int x)
    {
        if (p[x] != x)
        {
            int u = find(p[x]);
            d[x] += d[p[x]];
            p[x] = u;
        }
        return p[x];
    }

    // 初始化,假定节点编号是1~n
    for (int i = 1; i <= n; i ++ )
    {
        p[i] = i;
        d[i] = 0;
    }

    // 合并a和b所在的两个集合:
    p[find(a)] = find(b);
    d[find(a)] = distance; // 根据具体问题,初始化find(a)的偏移量distance

原题链接:合并集合

#include<iostream>

using namespace std;

const int N = 100010;

int p[N];//每个元素的父节点
int find(int x)//返回祖宗节点+路径压缩
{
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    int n , m;
    scanf("%d%d",&n,&m);
    for(int i = 1 ; i <= n ; i ++ ) p[i] = i ; //初始化,使头结点指向自己
    while(m--)
    {
        char op[2];
        int a,b;
        scanf("%s%d%d",op,&a,&b);
        if(op[0] == 'M') p[find(a)] = find(b);
        else {
            if(find(a) == find(b)) puts("Yes");
            else puts("No");
        }
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
	public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	public static PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
	public static int N = 100010;
	public static int n, m;
	public static int p[] = new int[N]; // 每个元素的父节点

	public static int find(int x) {
		if (p[x] != x)
			p[x] = find(p[x]);
		return p[x];
	}

	public static void main(String[] args) throws Exception {
		String[] str = br.readLine().split(" ");
		n = Integer.parseInt(str[0]);
		m = Integer.parseInt(str[1]);
		for(int i = 1 ; i <= n ; i ++ ) p[i] = i;
		while (m-- > 0) {
			str = br.readLine().split(" ");
			int a, b;
			a = Integer.parseInt(str[1]);
			b = Integer.parseInt(str[2]);
			if (str[0].equals("M")) {
				p[find(a)] = find(b);
			} else {
				if (find(a) == find(b))
					out.println("Yes");
				else
					out.println("No");
			}
		}
		out.flush();
	}
}

原题链接:连通块中点的数量

#include<iostream>
#include<algorithm>
#include<cstring>

using namespace std;
const int N = 1e5;

int n, m;
int p[N];
int s[N];

int find(int x){
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    cin >> n >> m;
    for(int i = 1 ; i <= n ; i ++ ){
        p[i] = i;
        s[i] = 1;
    } 
    for(int i = 0 ; i < m ; i ++ ){
        string op;
        int a, b;
        cin >> op;
        if(op == "C"){
            cin >> a >> b;
            if(find(a) == find(b)) continue;
            s[find(b)] += s[find(a)];
            p[find(a)] = find(b);            
        }else if(op == "Q1"){
            cin >> a >> b;
            if(find(a) == find(b)) cout << "Yes" << endl;
            else cout << "No" << endl;
        }else {
            cin >> a;
            cout << s[find(a)] << endl;
        }
    }
    
    return 0;
}

食物链

#include<iostream>
#include<algorithm>
#include<cstring>

using namespace std;
const int N = 50010;

int n, m;
int d[N];
int p[N];

int find(int x){
    if(p[x] != x) {
        int u = find(p[x]);
        d[x] += d[p[x]];
        p[x] = u;
    }
    return p[x];
}

int main()
{
    cin >> n >> m ;
    for(int i = 1 ; i <= n ; i ++) p[i] = i;
    int res = 0;
    for(int i = 0 ; i < m ; i ++ ){
        int D, x, y;
        cin >> D >> x >> y;
        if(x > n || y > n) res ++ ;
        else {
            int px = find(x), py = find(y);
            if(D == 1) {
                if(px == py && (d[x] - d[y]) % 3) res ++;
                else if(px != py){
                    p[px] = py;
                    d[px] = d[y] - d[x]; // (d[x] + ? - d[y]) mod 3 == 0;
                }
            }
            else {
                if(px == py && (d[x] - d[y] - 1) % 3) res ++;
                else if(px != py)
                {
                    p[px] = py;
                    d[px] = d[y] + 1 - d[x]; //(d[x] + ? - d[y] - 1) mod 3 == 0;
                }
            }
        }
    }
    
    cout << res << endl;
    return 0;
}

模板:

// h[N]存储堆中的值, h[1]是堆顶,x的左儿子是2x, 右儿子是2x + 1
// ph[k]存储第k个插入的点在堆中的位置
// hp[k]存储堆中下标是k的点是第几个插入的
int h[N], ph[N], hp[N], size;

// 交换两个点,及其映射关系
void heap_swap(int a, int b)
{
    swap(ph[hp[a]],ph[hp[b]]);
    swap(hp[a], hp[b]);
    swap(h[a], h[b]);
}

void down(int u)
{
    int t = u;
    if (u * 2 <= size && h[u * 2] < h[t]) t = u * 2;
    if (u * 2 + 1 <= size && h[u * 2 + 1] < h[t]) t = u * 2 + 1;
    if (u != t)
    {
        heap_swap(u, t);
        down(t);
    }
}

void up(int u)
{
    while (u / 2 && h[u] < h[u / 2])
    {
        heap_swap(u, u / 2);
        u >>= 1;
    }
}

// O(n)建堆
for (int i = n / 2; i; i -- ) down(i);

原题链接:https://www.acwing.com/problem/content/840/

#include<iostream>
#include<algorithm>
#include<string.h>

using namespace std;

const int N = 100010;
int h[N],hp[N],ph[N],mysize;

void heap_swap(int a,int b)
{
    swap(ph[hp[a]],ph[hp[b]]);
    swap(hp[a],hp[b]);
    swap(h[a],h[b]);
}

void down(int u)
{
    int t = u;
    if(u*2 <= mysize && h[u*2] <= h[t]) t = u*2;
    if(u*2+1 <= mysize && h[u*2+1] <= h[t]) t = u*2 +1;
    if(u!=t)
    {
        heap_swap(u,t);
        down(t);
    }
}

void up(int u)
{
    while(u/2 && h[u/2] > h[u])
    {
        heap_swap(u,u/2);
        u /= 2;
    }
}


int main()
{
    int n,m = 0;
    scanf("%d",&n);
    while(n--)
    {
        char op[10];
        scanf("%s",op);
        int x,k;
        if(!strcmp(op,"I"))
        {
            scanf("%d",&x);
            mysize ++ ;
            m ++ ;
            ph[m] = mysize , hp[mysize] = m;
            h[mysize] = x;
            up(mysize);
        }
        else if(!strcmp(op,"PM")) printf("%d\n",h[1]);
        else if(!strcmp(op,"DM"))
        {
            heap_swap(1,mysize);
            mysize -- ;
            down(1);
        }
        else if(!strcmp(op,"D"))
        {
            scanf("%d",&k);
            k = ph[k];
            heap_swap(k,mysize);
            mysize -- ;
            down(k), up(k);
        }
        else 
        {
            scanf("%d%d",&k,&x);
            k = ph[k];
            h[k] = x;
            down(k),up(k);
        }        
    }
    return 0;
}

哈希表

取模的数要取质数,且离2的整数次幂最远

拉链法

模板:

(1) 拉链法
    int h[N], e[N], ne[N], idx;

    // 向哈希表中插入一个数
    void insert(int x)
    {
        int k = (x % N + N) % N;
        e[idx] = x;
        ne[idx] = h[k];
        h[k] = idx ++ ;
    }

    // 在哈希表中查询某个数是否存在
    bool find(int x)
    {
        int k = (x % N + N) % N;
        for (int i = h[k]; i != -1; i = ne[i])
            if (e[i] == x)
                return true;

        return false;
    }

(2) 开放寻址法
    int h[N];

    // 如果x在哈希表中,返回x的下标;如果x不在哈希表中,返回x应该插入的位置
    int find(int x)
    {
        int t = (x % N + N) % N;
        while (h[t] != null && h[t] != x)
        {
            t ++ ;
            if (t == N) t = 0;
        }
        return t;
    }

原题链接:https://www.acwing.com/problem/content/842/

#include<iostream>
#include<cstring>
using namespace std;

const int N = 100003;//大于十万的第一个质数

int h[N];//哈希表的槽
int e[N],ne[N],idx;//每个槽的链表

void insert(int x)
{
    int k = (x % N + N) % N ;//让余数变成正数
    
    e[idx] = x;
    ne[idx] = h[k];
    h[k] = idx ++ ;
}

bool find(int x)
{
    int k = (x % N + N) % N;
    for(int i = h[k] ; i != -1 ; i = ne[i])
    {
        if(e[i] == x) return true;
    }
    return false;
}

int main()
{
    int n;
    scanf("%d",&n);
    
    memset(h,-1,sizeof h);
    while(n--)
    {
        char op[2];
        int x;
        scanf("%s%d",op,&x);
        if(*op == 'I') insert(x);
        else 
        {
            if(find(x)) puts("Yes");
            else puts("No");
        }
    }
    return 0;
}

开放寻址法

#include<iostream>
#include<cstring>
using namespace std;

const int N = 200003, null = 0x3f3f3f3f;//大于二十万的第一个质数

int h[N];//哈希表的槽

//若存在元素,find函数返回位置,若不存在,则返回应该存储的位置
int find(int x)
{
    int k = (x % N + N)%N;
    while(h[k] != null && h[k] != x)//当前位置不为空,且元素不是他自己
    {
        k ++ ;
        if(k == N) k = 0;//若k走到最后的一个位置,则从头开始找
    }
    return k;
}

int main()
{
    int n;
    scanf("%d",&n);
    
    memset(h,0x3f,sizeof h);//memset是按字节进行
    while(n--)
    {
        char op[2];
        int x;
        scanf("%s%d",op,&x);
        int k = find(x);
        if(*op == 'I') h[k] = x;
        else 
        {
            if(h[k] != null) puts("Yes");
            else puts("No");
        }
    }
    return 0;
}

字符串前缀哈希法

模板:

核心思想:将字符串看成P进制数,P的经验值是131或13331,取这两个值的冲突概率低
小技巧:取模的数用2^64,这样直接用unsigned long long存储,溢出的结果就是取模的结果

typedef unsigned long long ULL;
ULL h[N], p[N]; // h[k]存储字符串前k个字母的哈希值, p[k]存储 P^k mod 2^64

// 初始化
p[0] = 1;
for (int i = 1; i <= n; i ++ )
{
    h[i] = h[i - 1] * P + str[i];
    p[i] = p[i - 1] * P;
}

// 计算子串 str[l ~ r] 的哈希值
ULL get(int l, int r)
{
    return h[r] - h[l - 1] * p[r - l + 1];
}

原题链接:https://www.acwing.com/problem/content/843/

预处理字符串前缀的哈希

Q:如何定义某一个前缀的哈希值

A:把字符串看成一个P进制的数

如果实现取字符串的两个子串匹配,每取一次就计算一次哈希值的时间复杂度也会到O(n)。最好的方法就是预处理求前缀和,把时间复杂度降低为O(1)。
当得到前缀和数组后,每次对[l, r]区间字符串的访问就可以直接得到其哈希值了。
对应的两个公式分别是:

预处理:h[i] = h[i - 1] * P + str[i]
每次查询[l, r]:h[r] - h[l - 1] * Pr-l+1
注:为什么不是 h[r] - h[l-1] 而是 h[r] - h[l - 1] * pr-l+1?理由如下:
给定字符串ABCDE,想知道ABC和DE,那么看成P进制的数之后,DE = ABCDE - ABC00(其中:ABC00 = ABC * p2),而不是ABCDE - ABC。

#include<iostream>

using namespace std;

typedef unsigned long long ULL;

const int N = 100010, P = 131;

int n,m;
char str[N];
ULL h[N],p[N];

ULL get(int l,int r)
{
    return h[r] - h[l-1] * p[r-l+1];
}
int main()
{
    scanf("%d%d%s",&n,&m,str+1);
    p[0] = 1;
    for(int i = 1 ; i <= n ; i ++ )
    {
        p[i] = p[i-1] * P;
        h[i] = h[i-1] * P + str[i];
    }
    
    while(m--)
    {
        int l1,r1,l2,r2;
        scanf("%d%d%d%d",&l1,&r1,&l2,&r2);
        if(get(l1,r1) == get(l2,r2)) puts("Yes");
        else puts("No");
    }
    
    return 0;
} 

stl

vector, 变长数组,倍增的思想
    size()  返回元素个数
    empty()  返回是否为空
    clear()  清空
    front()/back() 第一个数/最后一个数
    push_back()/pop_back() 向最后插入一个数/删掉最后一个数
    begin()/end() 第0个数/最后一个数的后面一个数 即a[0]/a[a.size()]
    [] 支持随机寻址
    支持比较运算,按字典序

pair<int, int>
    first, 第一个元素
    second, 第二个元素
    支持比较运算,以first为第一关键字,以second为第二关键字(字典序)

string,字符串
    size()/length()  返回字符串长度
    empty()
    clear()
    substr(起始下标,(子串长度))  返回子串
    c_str()  返回字符串所在字符数组的起始地址

queue, 队列
    size()
    empty()
    push()  向队尾插入一个元素
    front()  返回队头元素
    back()  返回队尾元素
    pop()  弹出队头元素

priority_queue, 优先队列,默认是大根堆
    size()
    empty()
    push()  插入一个元素
    top()  返回堆顶元素
    pop()  弹出堆顶元素
    定义成小根堆的方式:priority_queue<int, vector<int>, greater<int>> q;

stack, 栈
    size()
    empty()
    push()  向栈顶插入一个元素
    top()  返回栈顶元素
    pop()  弹出栈顶元素

deque, 双端队列
    size()
    empty()
    clear()
    front()/back()
    push_back()/pop_back()
    push_front()/pop_front()
    begin()/end()
    []

set, map, multiset, multimap, 基于平衡二叉树(红黑树),动态维护有序序列
    size()
    empty()
    clear()
    begin()/end()
    ++, -- 返回前驱和后继,时间复杂度 O(logn)

    set/multiset
        insert()  插入一个数
        find()  查找一个数
        count()  返回某一个数的个数
        erase()
            (1) 输入是一个数x,删除所有x   O(k + logn)
            (2) 输入一个迭代器,删除这个迭代器
        lower_bound()/upper_bound()
            lower_bound(x)  返回大于等于x的最小的数的迭代器
            upper_bound(x)  返回大于x的最小的数的迭代器
    map/multimap
        insert()  插入的数是一个pair
        erase()  输入的参数是pair或者迭代器
        find()
        []  注意multimap不支持此操作。 时间复杂度是 O(logn)
        lower_bound()/upper_bound()

unordered_set, unordered_map, unordered_multiset, unordered_multimap, 哈希表
    和上面类似,增删改查的时间复杂度是 O(1)
    不支持 lower_bound()/upper_bound(), 迭代器的++,--

bitset, 圧位
    最主要特点是可以省空间
    bitset<10000> s; bit<个数>
    ~, &, |, ^
    >>, << 移位操作
    ==, !=
    []

    count()  返回有多少个1

    any()  判断是否至少有一个1
    none()  判断是否全为0

    set()  把所有位置成1
    set(k, v)  将第k位变成v
    reset()  把所有位变成0
    flip()  等价于~
    flip(k) 把第k位取反

vector

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

int main()
{
    vector<int> a(10,3); //定义长度为10的数组并初始化为3
    vector<int> a;
    
    for(int i = 0; i< 10;i++) a.push_back(i);
    for(int i = 0; i < a.size();i++) cour << a[i] << ' ';
    cout << endl;
    
    for(vector<int>::iterator i = a.begin(); i ++ ) cout << *i << ' ';
    cout << endl;
    
    for(auto x : a) cout << x << ' ';
    cout << endl; //auto自动推断变量类型
    
    //支持比较
    /*
    如果两个vector对象的容量不同,但是相同位置上的元素值都一样,则元素较少的vector对象小于元素较多的vector对象。
若元素的值有区别,则vector对象的大小关系由第一对相异的元素值的大小关系决定。
*/
    vector<int> a(4,3) , b(3,4);
    if(a<b) puts("a<b");
    return 0;
    
}

系统为某个程序分配空间时,所需时间与空间大小无关,与申请次数有关

vector优化:数组需要变长时,就把数组元素×2——倍增的思想

pair

pair<int,string>p;
p.first;
p.second;

p = make_pair(10,"yxc");
p = {20, "abc"};
pair<int,pair<int,int>>p;

string

string a = "yxc";
a += "def";
a += 'c';

cout << a.substr(1,2) << endl //xc

queue

queue<int> q;
q = queue<int>();

priority_queue

优先队列,默认是大根堆

priority_queue<int> heap;

//插入时按负数插入,可实现小根堆
heap.push(-x);

//直接定义小根堆
priority_queue<int,vector<int>,greater<int>> heap;

算法提高课

并查集

格子游戏

#include<iostream>
#include<cstring>
#include<algorithm>

using namespace std;
const int N = 40010;
int n, m;
int p[N];

int get(int a,int b){
    return a * n + b;
}

int find(int x){
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    cin >> n >> m ;
    for(int i = 0 ; i < n * n ; i ++ ) p[i] = i;
    
    int res = 0;
    for(int i = 1 ; i <= m ; i ++ ){
        int x, y;
        char d;
        cin >> x >> y >> d;
        x-- , y--;
        int a = get(x, y);
        int b;
        if(d == 'D') b = get(x + 1, y);
        else b = get(x, y + 1);
        
        int pa = find(a), pb = find(b);
        if(pa == pb){
          res = i;
          break;
        } 
        p[pa] = pb;
    }
    if(!res) cout << "draw" << endl;
    else cout << res << endl;
    return 0;
}

搭配购买

#include<iostream>
#include<cstring>
#include<algorithm>

using namespace std;
const int N = 10010;

int n, m, vol;
int v[N], w[N];
int p[N];
int f[N];

int find(int x){
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    cin >> n >> m >> vol;
    for(int i = 1 ; i <= n ; i ++ ) p[i] = i;
    for(int i = 1 ; i <= n ; i ++ ){
        cin >> v[i] >> w[i];
    }
    
    for(int i = 0 ; i < m ; i ++ ){
        int a, b;
        cin >> a >> b;
        int pa = find(a);
        int pb = find(b);
        if(pa != pb){
            v[pb] += v[pa];
            w[pb] += w[pa];
            p[pa] = pb;
        }
    }
    
    //01背包
    for(int i = 1 ; i <= n ; i++ )
        if(p[i] == i)
        {
            for(int j = vol; j >= v[i]; j -- ){
                f[j] = max(f[j], f[j - v[i]] + w[i]);
            }
        }
        
    cout << f[vol] << endl;
    return 0;
}

程序自动分析

#include<iostream>
#include<algorithm>
#include<cstring>
#include<unordered_map>

using namespace std;
const int N = 200010;

int n,m ;
int p[N];
unordered_map<int,int> S;

struct Query{
    int x, y, e;
}query[N];

int get(int x){
    if(S.count(x) == 0) S[x] = ++n ;
    return S[x];
}

int find(int x){
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    int T;
    cin >> T;
    while(T --){
        n = 0;
        cin >> m ;
        S.clear();
        for(int i = 0 ; i < m ; i ++ ){
            int x, y, e;
            cin >> x >> y >> e;
            query[i] = {get(x), get(y), e};
        }
        
        for(int i = 1; i <= n ; i ++ ) p[i] = i;
        
        for(int i = 0 ; i < m ; i ++ ){
            if(query[i].e == 1){
                int pa = find(query[i].x);
                int pb = find(query[i].y);
                p[pa] = pb;
            }
        }
        
        bool flag = false;
        for(int i = 0 ; i < m ; i ++ ){
            if(query[i].e == 0){
                int pa = find(query[i].x);
                int pb = find(query[i].y);
                if(pa == pb){
                    flag = true;
                    break;
                }
            }
        }
        
        if(flag) cout << "NO" << endl;
        else cout << "YES" << endl;
    } 
    
    return 0;
}

银河英雄传说

#include<iostream>
#include<algorithm>

using namespace std;
const int N = 30010;

int m;
int p[N], sz[N], d[N];

int find(int x){
    if(p[x] != x){
        int u = find(p[x]);
        d[x] += d[p[x]];
        p[x] = u;
    }
    return p[x];
}

int main()
{
    cin >> m ;
    
    for(int i = 1 ; i < N ; i ++ ){
        p[i] = i;
        sz[i] = 1;
    }
    for(int i = 0 ; i < m ; i ++ ){
        char op[2];
        int a, b;
        cin >> op >> a >> b;
        if(op[0] == 'M'){
            int pa = find(a), pb = find(b);
            if(pa != pb){
                d[pa] = sz[pb];
                sz[pb] += sz[pa];
                p[pa] = pb;                
            }
        }else {
            int pa = find(a), pb = find(b);
            if(pa != pb) cout << "-1" << endl;
            else cout << max(abs(d[a] - d[b]) - 1 , 0) << endl;
        }
    }
    
    return 0;
}

奇偶游戏

带边权写法

#include<iostream>
#include<cstring>
#include<algorithm>

using namespace std;
const int N = 20010;

int n, m;
int p[N], d[N];
unordered_map<int,int> S;

int get(int x){
    if(S.count(x) == 0) S[x] = ++ n;
    return S[x];
}

int find(int x){
    if(p[x] != x){
        int u = find(p[x]);
        d[x] += d[p[x]];
        p[x] = u;
    }
    
    return p[x];
}

int main()
{
    cin >> n >> m;
    n = 0;
    for(int i = 0 ; i < N ; i++ ) p[i] = i;
    int res = m;
    for(int i = 1 ; i <= m ; i ++ ){
        int a, b;
        string type;
        cin >> a >> b >> type;
        a = get(a - 1), b = get(b);
        
        int t = 0;
        if(type == "odd") t = 1;
        int pa = find(a), pb = find(b);
        if(pa == pb){
            if(((d[a] + d[b]) % 2 + 2) % 2 != t)
            {
                res = i - 1 ;
                break;
            }
        }
        else {
            p[pa] = pb;
            d[pa] = d[a] ^ d[b] ^ t;
        }
    }
    
    cout << res << endl;
    
    return 0;
}

扩展域写法

#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>

using namespace std;

const int N = 40010, Base = N / 2;

int n, m;
int p[N];
unordered_map<int, int> S;

int get(int x)
{
    if (S.count(x) == 0) S[x] = ++ n;
    return S[x];
}

int find(int x)
{
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    cin >> n >> m;
    n = 0;

    for (int i = 0; i < N; i ++ ) p[i] = i;

    int res = m;
    for (int i = 1; i <= m; i ++ )
    {
        int a, b;
        string type;
        cin >> a >> b >> type;
        a = get(a - 1), b = get(b);

        if (type == "even")
        {
            if (find(a + Base) == find(b))
            {
                res = i - 1;
                break;
            }
            p[find(a)] = find(b);
            p[find(a + Base)] = find(b + Base);
        }
        else
        {
            if (find(a) == find(b))
            {
                res = i - 1;
                break;
            }

            p[find(a + Base)] = find(b);
            p[find(a)] = find(b + Base);
        }
    }

    cout << res << endl;

    return 0;
}

树状数组

楼兰图腾

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int N = 2000010;

typedef long long LL;

int n;
//t[i]表示树状数组i结点覆盖的范围和
int a[N], t[N];
//Lower[i]表示左边比第i个位置小的数的个数
//Greater[i]表示左边比第i个位置大的数的个数
int Lower[N], Greater[N];

//返回非负整数x在二进制表示下最低位1及其后面的0构成的数值
int lowbit(int x)
{
    return x & -x;
}

//将序列中第x个数加上k。
void add(int x, int k)
{
    for(int i = x; i <= n; i += lowbit(i)) t[i] += k;
}
//查询序列前x个数的和
int ask(int x)
{
    int sum = 0;
    for(int i = x; i; i -= lowbit(i)) sum += t[i];
    return sum;
}

int main()
{

    scanf("%d", &n);
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]);

    //从左向右,依次统计每个位置左边比第i个数y小的数的个数、以及大的数的个数
    for(int i = 1; i <= n; i++)
    {
        int y = a[i]; //第i个数

        //在前面已加入树状数组的所有数中统计在区间[1, y - 1]的数字的出现次数
        Lower[i] = ask(y - 1); 

        //在前面已加入树状数组的所有数中统计在区间[y + 1, n]的数字的出现次数
        Greater[i] = ask(n) - ask(y);

        //将y加入树状数组,即数字y出现1次
        add(y, 1);
    }

    //清空树状数组,从右往左统计每个位置右边比第i个数y小的数的个数、以及大的数的个数
    memset(t, 0, sizeof t);

    LL resA = 0, resV = 0;
    //从右往左统计
    for(int i = n; i >= 1; i--)
    {
        int y = a[i];
        resA += (LL)Lower[i] * ask(y - 1);
        resV += (LL)Greater[i] * (ask(n) - ask(y));

        //将y加入树状数组,即数字y出现1次
        add(y, 1);
    }

    printf("%lld %lld\n", resV, resA);

    return 0;
}

文章作者: wck
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 wck !
评论
  目录