Rによる能動的推論モデル入門
本スライドに関して,間違っている・間違ってそう・ご不明な内容がありましたら,国里愛彦の研究室メールフォームからお知らせください(共同研究もお待ちしています)
\[ p(s|o) = \frac{p(o|s)p(s)}{p(o)} \]
→状態空間モデルの枠組みで各種ベイズ推論モデルを位置づけて理解する。
→変動性を明示的に扱う
\[ F = \int dx q(s)log \frac{q(s)}{p(s,o)} \]
\[ F = D_{KL}[q(s)\parallel p(s|o)]- \log p(o) \]
→\(D_{KL}\)はゼロ以上なので,\(\textbf{F}\)がサプライザルの上限を提供。\(\textbf{F}\)を最小化する\(q\)の探索を通してサプライザルを最小化(推論問題を\(\textbf{F}\)を最小化する最適化問題へ)
→階層的神経回路による予測誤差最小化で,\(\textbf{F}\)を最小化
→変分自由エネルギー\(\textbf{F}\)ではなく期待自由エネルギー\(\textbf{G}\)を計算する
\[ G(\pi) = - \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[D_{KL}[q(\tilde{s}|\tilde{o},\pi) \parallel q(\tilde{s}|\pi)]]- \mathbb{E}_{q(\tilde{o}|\pi)}[\log p(\tilde{o}|C)] \]
\[ \begin{align} \textbf{G}_{\pi} = \textbf{H} \cdot \textbf{s}_{\pi \tau} + \textbf{o}_{\pi \tau} \cdot \varsigma_{\pi \tau} \\ \varsigma_{\pi \tau} = \log \textbf{o}_{\pi \tau} - \log \textbf{C}_{\tau} \\ \textbf{H} = -diag(\textbf{A}\cdot \log\textbf{A}) \end{align} \]
\[ \begin{align} F = \pi \cdot \textbf{F} \\ \textbf{F}_{\pi} = \Sigma_{\tau}\textbf{F}_{\pi \tau} \\ \textbf{F}_{\pi \tau} = \textbf{s}_{\pi \tau} \cdot (\log \textbf{s}_{\pi \tau} - \log \textbf{A}\cdot o_{\tau} - \log \textbf{B}_{\pi \tau}\textbf{s}_{\pi \tau-1}) \end{align} \]
library()
でパッケージ読み込み)# expand.gridを使ってグリッド位置を作る関数
create_grid_locations <- function() {
grid <- expand.grid(y = 1:3,x = 1:3)
grid_locations <- Map(c, grid$x, grid$y)
return(grid_locations)
}
# グリッドをプロットする関数
plot_grid <- function(grid_locations, num_x = 3, num_y = 3) {
# 場所を用意
grid_heatmap <- matrix(0, nrow = num_y, ncol = num_x)
counter <- 1
for (i in 1:num_y) {
for (j in 1:num_x) {
grid_heatmap[i, j] <- counter
counter <- counter + 1
}
}
# データフレームに変換
plot_data <- data.frame(
y = rep(1:num_y, num_x),
x = rep(1:num_x, each = num_y),
value = as.vector(grid_heatmap)
)
# ヒートマップの色の設定
custom_colors <- scale_fill_gradient(low = "#E8F4D9",high = "#1A365D")
# ヒートマップの作成
ggplot(plot_data, aes(x = x, y = y, fill = value)) +
geom_tile() +
geom_text(aes(label = sprintf("%.0f", value)),
size = 6,
color = ifelse(plot_data$value > 5, "white", "black")) +
custom_colors +
scale_y_reverse() +
theme_minimal(base_size = 15) +
theme(legend.position = "none",
axis.title.x = element_blank(),
axis.title.y = element_blank()) +
coord_fixed()
}
#
diag()
で,行列の対角成分が1になるように設定)# 2次元尤度行列をヒートマップとしてプロットする関数
plot_likelihood <- function(matrix, xlabels = 1:9, ylabels = 1:9, title_str = "Likelihood distribution (A)") {
# 列方向の和が1かどうかを確認
if (!all(near(colSums(matrix), 1.0))) {
stop("Distribution not column-normalized! Please normalize (ensure colSums(matrix) == 1.0 for all columns)")
}
# 行列をデータフレームに変換
df <- as.data.frame(matrix)
colnames(df) <- xlabels
rownames(df) <- ylabels
df <- df %>%
mutate(y = factor(rownames(df), levels = rev(rownames(df)))) %>%
pivot_longer(cols = -y, names_to = "x", values_to = "value")
# ggplot2でヒートマップを作成
ggplot(df, aes(x = x, y = y, fill = value)) +
geom_tile() +
scale_fill_gradient(low = "black", high = "white", limits = c(0, 1)) +
labs(title = title_str, x = "隠れ状態の位置", y = "観測の位置") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
}
#
# アクション定義
actions <- c("UP", "DOWN", "LEFT", "RIGHT", "STAY")
# B行列作成関数
create_B_matrix <- function() {
num_states <- length(grid_locations)
num_actions <- length(actions)
# 3次元配列の初期化
B <- array(0, dim = c(num_states, num_states, num_actions))
# 各アクションと状態に対する遷移を計算
for (action_id in seq_along(actions)) {
action_label <- actions[action_id]
for (curr_state in seq_along(grid_locations)) {
# 現在の位置を取得
grid_location <- grid_locations[[curr_state]]
y <- grid_location[1]
x <- grid_location[2]
# アクションに基づいて次の位置を計算
if (action_label == "UP") {
next_y <- ifelse(y > 1, y - 1 ,y)
next_x <- x
} else if (action_label == "DOWN") {
next_y <- ifelse(y < 3, y + 1, y)
next_x <- x
} else if (action_label == "LEFT") {
next_x <- ifelse(x > 1, x - 1, x)
next_y <- y
} else if (action_label == "RIGHT") {
next_x <- ifelse(x < 3, x + 1, x)
next_y <- y
} else if (action_label == "STAY") {
next_x <- x
next_y <- y
}
# 新しい位置を作成
new_location <- c(next_y, next_x)
# 新しい位置のインデックスを検索
next_state <- which(sapply(grid_locations, function(loc)
all(loc == new_location)))
# 遷移確率を設定
B[next_state, curr_state, action_id] <- 1.0
}
}
return(B)
}
#
plot_point_on_grid <- function(state_vector, grid_locations) {
# 状態のベクトルから状態のインデックスを取得
state_index <- which(state_vector == 1)
# グリッドの位置からxとyの値を抽出
coords <- grid_locations[[state_index]]
x <- coords[1]
y <- coords[2]
# 空のグリッドを用意してグリッドの位置に値をいれる
grid_heatmap <- matrix(0, nrow = 3, ncol = 3)
grid_heatmap[y, x] <- 1.0
# プロット用のデータ用意
plot_data <- data.frame(
y = rep(1:nrow(grid_heatmap), each = ncol(grid_heatmap)),
x = rep(1:ncol(grid_heatmap), times = nrow(grid_heatmap)),
value = as.vector(grid_heatmap),
index = 1:9
)
ggplot(plot_data, aes(x = x, y = y, fill = value)) +
geom_tile() +
geom_text(aes(label = sprintf("%.0f", index)),
size = 6,
color = "white") +
scale_fill_gradient(low = "black", high = "grey", limits = c(0, 1)) +
theme_minimal() +
theme(legend.position = "none") +
scale_y_reverse() +
theme(axis.title.x = element_blank(),
axis.title.y = element_blank()) +
coord_fixed()
}
onehot <- function(index, length) {
vec <- numeric(length)
vec[index] <- 1
return(vec)
}
#
plot_beliefs <- function(belief_dist, title_str = "") {
# プロット用データの用意
plot_data <- data.frame(
x = 1:length(belief_dist),
probability = belief_dist
)
ggplot(plot_data, aes(x = x, y = probability)) +
geom_bar(stat = "identity", fill = "red") +
scale_x_continuous(breaks = 1:(length(belief_dist))) +
scale_y_continuous(limits = c(0, 1)) +
labs(title = title_str,x = "State",y = "Probability")
}
# Cを格納する空のベクトルを作成
C <- numeric(n_observations)
# 望ましい報酬のある位置(Y=3,X=3)とそのインデックス(9)を用意
desired_location <- c(3,3)
desired_location_index <- which(sapply(grid_locations, function(x) all(x == desired_location)))
# 望ましい位置の選好を100%(1.0)に設定
C[desired_location_index] <- 1.0
# 事前選好の分布を表示
plot_beliefs(C, title_str = "Preferences over observations")
\[ q(s_t) = \sigma\left(\ln \mathbf{A}[o,:] + \ln\mathbf{B}[:,:,a] \cdot q(s_{t-1})\right) \]
# 対数関数
log_stable <- function(arr) {
# 非常に小さい値を避けるための下限値とその場合の置き換え
MIN_VALUE <- 1e-16
arr[arr < MIN_VALUE] <- MIN_VALUE
return(log(arr))
}
# softmax関数
softmax <- function(arr) {
# オーバーフローを防ぐため,最大値を引く
arr_shifted <- arr - max(arr)
exp_arr <- exp(arr_shifted)
return(exp_arr / sum(exp_arr))
}
# 状態推論関数
infer_states <- function(observation_index, A, prior) {
# 尤度の対数を計算(上の式の第1項)
log_likelihood <- log_stable(A[observation_index, ])
# 事前分布の対数を計算(上の式の第2項(Bとsの積は別に計算))
log_prior <- log_stable(prior)
# softmaxを適用して事後分布を計算
qs <- softmax(log_likelihood + log_prior)
return(qs)
}
#
\[ P(s_t) = \mathbf{E}_{q(s_{t-1})}[P(s_t | s_{t-1}, a_{t-1})] \]
infer_states()
関数(観測による状態の更新)を実行する。# 期待される状態を計算する関数
get_expected_states <- function(B, qs_current, action) {
# 特定のアクションが与えられた時の1ステップ先の期待される状態を計算
qs_a <- B[, , action] %*% qs_current
return(qs_a)
}
# 期待される観測を計算する関数
get_expected_observations <- function(A, qs_a) {
# 特定のアクションが与えられた時の1ステップ先の期待される観測を計算
qo_a <- A %*% qs_a
return(qo_a)
}
# エントロピーHを計算する関数
entropy <- function(A) {
H_A <- -colSums(A * log_stable(A))
return(H_A)
}
# KLダイバージェンスを計算する関数
kl_divergence <- function(qo_a, C) {
# 2つの1次元カテゴリカル分布間のKullback-Leiblerダイバージェンスを計算
return(sum((log_stable(qo_a) - log_stable(C)) * qo_a))
}
#
get_expected_states()
:期待される状態 \(Q(s_{t+1}|a_t)\) を,特定の行為に対する状態遷移行列*現在の状態で計算get_expected_observations()
:期待される観測 \(Q(o_{t+1}|a_t)\) を,尤度*現在の状態で計算entropy()
:尤度 \(P(o|s)\) のエントロピー( \(\mathbf{H}\left[\mathbf{A}\right]\) )を, \(-diag(\textbf{A}\cdot ln \textbf{A})\) で計算kl_divergence()
:期待される観測と事前選好 \(\mathbf{C}\) とのKLダイバージェンスを, \((ln\textbf{o}_{\pi} - ln\textbf{C})\cdot \textbf{o}_{\pi}\) で計算\[ G(\pi) = - \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[D_{KL}[q(\tilde{s}|\tilde{o},\pi) \parallel q(\tilde{s}|\pi)]]- \mathbb{E}_{q(\tilde{o}|\pi)}[\log p(\tilde{o}|C)] \\ = \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[H[p(\tilde{o}|\tilde{s})]]-D_{KL}[q(\tilde{o}|\pi) \parallel p(\tilde{o}|C)] \]
\[ G(\pi) = - \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[D_{KL}[q(\tilde{s}|\tilde{o},\pi) \parallel q(\tilde{s}|\pi)]]- \mathbb{E}_{q(\tilde{o}|\pi)}[\log p(\tilde{o}|C)] \\ = \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[H[p(\tilde{o}|\tilde{s})]]-D_{KL}[q(\tilde{o}|\pi) \parallel p(\tilde{o}|C)] \]
\[ G(\pi) = - \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[D_{KL}[q(\tilde{s}|\tilde{o},\pi) \parallel q(\tilde{s}|\pi)]]- \mathbb{E}_{q(\tilde{o}|\pi)}[\log p(\tilde{o}|C)] \\ = \mathbb{E}_{q(\tilde{s},\tilde{o}|\pi)}[H[p(\tilde{o}|\tilde{s})]]-D_{KL}[q(\tilde{o}|\pi) \parallel p(\tilde{o}|C)] \]
# 右に移動する場合の期待される状態を計算
qs_a_right <- get_expected_states(B, qs_current, right_idx)
# 右に移動する場合期待される観測を計算
qo_a_right <- get_expected_observations(A, qs_a_right)
# 曖昧さの期待値
predicted_uncertainty_right <- sum(H_A * qs_a_right)
# リスク
predicted_divergence_right <- kl_divergence(qo_a_right, C)
# 右に移動する場合のGを計算
G[2] <- predicted_uncertainty_right + predicted_divergence_right
cat('曖昧さの期待値:', predicted_uncertainty_right, ',リスク', predicted_divergence_right, ',期待自由エネルギーG', G[2])
曖昧さの期待値: 0 ,リスク 0 ,期待自由エネルギーG 0
R6Class()
関数を使用する。initialize
メソッドは,クラスのインスタンスが作成されるときに実行される関数。ここでは,インスタンス作成時に入力されるname, ageを受け取って,フィールドに格納(その際にself$XXXを使う)。greet
のように,関数を設定しておくと,インスタンスで使えるようになる。オブジェクト名$new()
メソッドを使用する。オブジェクト名$関数名()
を使用する。initialize
メソッドで\(\textbf{ABCD}\)行列をいれる。infer_states
で状態の推定,calc_G
で期待自由エネルギーの計算,sample_action
でアクションのサンプリングAIagent <- R6Class("AIagent",
public = list(
# フィールドの定義
A = NULL,
B = NULL,
C = NULL,
D = NULL,
actions = NULL,
prior = NULL,
qs = NULL,
G = NULL,
# 初期化メソッド
initialize = function(A, B, C, D,actions) {
self$A = A
self$B = B
self$C = C
self$prior = D
self$actions = actions
},
# 状態の推定
infer_states = function(obs_index) {
# 尤度の対数を計算(上の式の第1項)
log_likelihood <- log_stable(self$A[obs_index, ])
# 事前分布の対数を計算(上の式の第2項(Bとsの積は別に計算))
log_prior <- log_stable(self$prior)
# softmaxを適用して事後分布を計算
self$qs <- softmax(log_likelihood + log_prior)
return(self$qs)
},
# 期待自由エネルギーGの計算
calc_G = function() {
# 尤度 P(o|s) のエントロピーHを計算
H_A <- -colSums(self$A * log_stable(self$A))
# 各アクションについて以下繰り返し
for(action_i in seq_along(self$actions)) {
# 現在のアクションにおける期待される状態を計算
qs_a <- self$B[, , action_i] %*% self$qs
# 現在のアクションにおける期待される観測を計算
qo_a <- self$A %*% qs_a
# 曖昧さの期待値を計算
pred_uncertainty <- sum(H_A * qs_a)
# リスクを計算
pred_div <- sum((log_stable(qo_a) - log_stable(self$C)) * qo_a)
# これらを合計して期待自由エネルギーGを得る
self$G[action_i] <- pred_uncertainty + pred_div
}
return(self$G)
},
# アクションの事後分布の計算とサンプリング
sample_action = function(){
# アクションの事後分布を計算
Q_a <- softmax(-self$G)
# アクションの確率分布からアクションをサンプリング
chosen_action <- sample(length(Q_a), size = 1, prob = Q_a)
# 次のタイムステップの推論のための事前分布を計算
self$prior <- self$B[, , chosen_action] %*% self$qs
return(chosen_action)
}
)
)
#
# グリッドワールドの環境クラスを定義
GridWorldEnv <- R6Class("GridWorldEnv",
public = list(
# フィールドの定義
init_state = NULL,
current_state = NULL,
# 初期化メソッド
initialize = function(starting_state = c(1,1)) {
self$init_state <- starting_state
self$current_state <- self$init_state
cat('Starting state is', paste(starting_state, collapse=","), '\n')
},
# 1ステップ進めるメソッド
step = function(action_label) {
Y <- self$current_state[1]
X <- self$current_state[2]
if(action_label == "UP") {
Y_new <- if(Y > 1) Y - 1 else Y
X_new <- X
} else if(action_label == "DOWN") {
Y_new <- if(Y < 3) Y + 1 else Y
X_new <- X
} else if(action_label == "LEFT") {
Y_new <- Y
X_new <- if(X > 1) X - 1 else X
} else if(action_label == "RIGHT") {
Y_new <- Y
X_new <- if(X < 3) X + 1 else X
} else if(action_label == "STAY") {
Y_new <- Y
X_new <- X
}
# 新しいグリッド位置を保存
self$current_state <- c(Y_new, X_new)
# エージェントは常に自分がいるグリッド位置を直接観測する
obs <- self$current_state
return(obs)
},
# 環境をリセットするメソッド
reset = function() {
self$current_state <- self$init_state
cat('Re-initialized location to', paste(self$init_state, collapse=","), '\n')
obs <- self$current_state
cat('..and sampled observation', paste(obs, collapse=","), '\n')
return(obs)
}
)
)
#
# 尤度Aを作成(単位行列)
A <- diag(n_states)
# 状態遷移行列Bを作成
B <- create_B_matrix()
# 選好Cを作成(位置9を好む)
desired_loc_idx <- which(sapply(grid_locations, function(x) all(x == c(3,3))))
C <- onehot(desired_loc_idx, n_observations)
# 初期信念Dを作成(位置6からスタートすると信じているとします)
start_loc_idx <- which(sapply(grid_locations, function(x) all(x == c(2,3))))
D <- onehot(start_loc_idx, n_states)
# 可能なアクションのベクトル
actions <- c("UP", "DOWN", "LEFT", "RIGHT", "STAY")
$new()
メソッドを使って,グリッドワールド環境のインスタンス化をする。$new()
メソッドを使って,能動的推論エージェントのインスタンス化をする。agent$XXX()
というやり方でメソッドを実行できる。また,メソッドを実行すれば,必要に応じてフィールドが更新されていく。plot_ai_agent()
を定義する(ただキレイなプロットを書くだけの関数なので,理解する必要はないです)plot_ai_agent <- function(qs_current,state_vector, grid_locations, actions, chosen_action,G, t){
p1 <- plot_beliefs(qs_current, title_str = paste("Beliefs about location at time", t))
# グリッドの位置のプロット
p2 <- plot_point_on_grid(state_vector, grid_locations)
# 期待自由エネルギーと選択した行為のプロット
chosen_data <- data.frame(
actions = actions,
selected_action = ifelse(seq_along(actions) == chosen_action, 1, 0),
G = G)
# 期待自由エネルギーのプロット
p3 <- ggplot(chosen_data, aes(x = actions, y = G)) +
geom_bar(stat = "identity", fill = "lightgreen") +
theme_minimal() +
labs(x = "Actions", y = "Expected Free Energy") +
theme(axis.text.x = element_text(angle = 0, hjust = 0.5))
# 選択した行為のプロット
p4 <- ggplot(chosen_data, aes(x = actions, y = selected_action)) +
geom_bar(stat = "identity", fill = "steelblue") +
ylim(0, 1) + # y軸の範囲を0から1に設定
theme_minimal() +
labs(x = "Actions", y = "Selected action") +
theme(axis.text.x = element_text(angle = 0, hjust = 0.5))
grid.arrange(p1, p2, p3, p4, nrow = 2)
}
#
# 試行番号の設定,環境のリセット,観測のインデックスを取得
t <- 1
obs <- env$reset()
obs_idx <- which(sapply(grid_locations, function(x) all(x == obs)))
# 状態に関する推論を実行
qs_current <- agent$infer_states(obs_idx)
# グリッドワールド上の位置を取得
state_idx <- which(sapply(grid_locations,function(x) all(x == obs)))
state_vector <- onehot(state_idx, n_states)
# 期待自由エネルギーを計算
G <- agent$calc_G()
# アクションのサンプリングとラベル取得
chosen_action <- agent$sample_action()
action_label <- actions[chosen_action]
# プロット
plot_ai_agent(qs_current, state_vector, grid_locations, actions, chosen_action, G, t)
Re-initialized location to 2,3
..and sampled observation 2,3
trial_n <- 5
env <- GridWorldEnv$new(starting_state = c(2,2))
agent <- AIagent$new(A, B, C, D,actions)
obs <- env$reset()
for (t in 1:trial_n) {
obs_idx <- which(sapply(grid_locations, function(x) all(x == obs)))
# 状態に関する推論を実行
qs_current <- agent$infer_states(obs_idx)
# グリッドワールド上の位置を取得
state_idx <- which(sapply(grid_locations,function(x) all(x == obs)))
state_vector <- onehot(state_idx, n_states)
# 期待自由エネルギーを計算
G <- agent$calc_G()
# アクションのサンプリングとラベル取得
chosen_action <- agent$sample_action()
action_label <- actions[chosen_action]
# アクションを踏まえた環境の更新
obs <- env$step(action_label)
# プロット
plot_ai_agent(qs_current, state_vector, grid_locations, actions, chosen_action, G, t)
}
#
env <- GridWorldEnv$new(starting_state = c(1,1))
で実行すると確認できる)階層的推論(深層生成モデル)
(バンディット課題も入れたかったが今回は分量的に断念…またの機会に…)