{10, 20, 10, 30, 20, 50} → LIS = 4 {10, 20, 30, 50}
핵심 알고리즘
Top-down
static int LIS(int N) {
// 만약 탐색하지 않던 위치의 경우
if (dp[N] == null) {
dp[N] = 1; // 1로 초기화
// N 이전의 노드들을 탐색
for (int i = N-1; i >= 0; i--) {
// 이전의 노드 중 seq[N]의 값보다 작은 걸 발견했을 경우
if (seq[i] < seq[N]) {
dp[N] = Math.max(dp[N], LIS(i) + 1);
}
}
}
return dp[N];
}
Java
복사
1) 먼저 N번째 값에 대해 이전에 탐색한 결과물이 있는지를 검사 → 만약 없다면 탐색하지 않았다는 의미이므로 dp[N]을 1로 초기화 → 모든 부분수열의 길이는 최소한 1 이상 이기 떄문
2) 그 후, N-1부터 0까지 N보다 작은 노드들을 탐색하면서, 해당 노드의 값이 N번째 값보다 작은 경우를 찾는다
3) i = 1에서 현재 탐색하고자 하는 값 dp[4] = Max(dp[4], LIS(2)+1) 을 갱신
→ 여기서 +1 하는 이유는 dp[N]이 이전 부분수열에 N번째 원소가 추가되었다는 의미이기 때문이다
4) dp[4]와 LIS(2)+1 중, 큰 값은 LIS(2) + 1 (부분수열 {10, 20}), 즉 2이기 때문에, dp[4]는 2로 갱신된다
5) 그리고 반복문을 마저 탐색하면 i=0일 때, seq[0] < seq[4]를 만족하므로, 재귀탐색을 하게 되는데, 마찬가지로 LIS(0) 또한 1 이므로 최대 길이기 2가된다
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
static Integer[] dp;
static int[] arr;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int num = Integer.parseInt(br.readLine());
dp = new Integer[num];
arr = new int[num];
StringTokenizer st = new StringTokenizer(br.readLine(), " ");
for (int i = 0; i < num; i++) {
arr[i] = Integer.parseInt(st.nextToken());
}
// 0 ~ num-1 까지 탐색
for (int i = 0; i < num; i++) {
LIS(i);
}
// 최댓값 찾기
int max = dp[0];
for (int i = 1; i < num; i++) {
max = Math.max(max, dp[i]);
}
System.out.println(max);
}
static int LIS(int num) {
if (dp[num] == null) {
dp[num] = 1;
for (int i = num-1; i >= 0; i--) {
if (arr[i] < arr[num]) {
dp[num] = Math.max(dp[num], LIS(i)+1);
}
}
}
return dp[num];
}
}
Java
복사
Bottom-up
for (int i = 0; i < N; i++) {
dp[i] = 1;
// 0 ~ i 이전 원소들 탐색
for (int j = 0; j < i; j++) {
// j번째 원소가 i번째 원소보다 작으면서, i번째 dp가 dp+1 값보다 작은 경우
if (seq[j] < seq[i] && dp[i] < dp[j] + 1) {
dp[i] = dp[j] + 1; // j번째 원소의 +1 값이 i번째 dp가 된다
}
}
}
Java
복사
import sys
sys.setrecursionlimit(10**6)
num = int(input())
arr = list(map(int, sys.stdin.readline().split()))
dp = [1] * num
for i in range(num):
for j in range(i):
if arr[j] < arr[i]:
dp[i] = max(dp[i], dp[j]+1)
print(max(dp))
Python
복사