kivantium活動日記

プログラムを使っていろいろやります

階層ベイズモデルを用いた労働時間と知的好奇心の関係分析

以前@berobero11さんに「StanとRでベイズ統計モデリング」をプレゼントしていただき、本を読んだのですが解析したいデータがなかったため勉強した結果を活かすことができずにいました。

しかし今日Twitterを見ていたら



というまさに階層ベイズモデルを用いた解析にぴったりのテーマが流れてきたのでStanの練習として解析を行ってみました。

モデリング

まずは奥村先生のPIAACデータ解析を読んでください。

この記事では

  • 個人の知的好奇心を個人の労働時間のみから説明する単回帰モデルで分析すると小さな負の相関がある。
  • 知的好奇心が労働時間と国の両方に関係しているという重回帰モデルで分析すると労働時間と知的好奇心の間に小さな正の相関がある。
  • 国ごとに知的好奇心の平均を労働時間の平均で説明する単回帰モデルで分析すると、非常に大きな負の相関がある。

という分析が行われています。

この結果を見ると、各個人の知的好奇心の差異を各個人の労働時間のみから説明するのは難しく、国の違いに由来する差がありそうだと分かります。
このような「説明変数だけでは説明がつかない、グループに由来する差異」をうまく扱うことができるのが階層モデルです。

第8章の階層モデルを見ながら以下のようなモデルを立てました。
{\displaystyle
\begin{align}
ll[n] &\sim \mathrm{Normal}(a[cntry[n]]+b[cntry[n]]\cdot wh[n], \sigma_{ll}) \\
a[k] &\sim \mathrm{Normal}(a_0, \sigma_{a}) \\
b[k] &\sim \mathrm{Normal}(b_0, \sigma_{b}) \\
\end{align}
}

 llは知的好奇心、 wh は労働時間、cntryは国を表します。nは各個人を表す添字、kは各国を表す添字です。
知的好奇心が労働時間の1次関数で説明できるというモデルで分析するところは先の分析と同じですが、階層モデルでは切片aと傾きbがそれぞれ国ごとに別の値を取ると考えます。
そして、各国での切片と傾きはある正規分布から生成されていると仮定し、それぞれのパラメータを推定します。
このようなモデルを用いて推定することでグループに共通する性質a_0, b_0とグループごとの性質a[k], b[k]を分けながらも同時に分析することができます。

解析

RとStanで分析を行います。データは奥村先生のサイトからダウンロードしました。

library(ggplot2)
library(data.table)
library(rstan)
data = fread("all.csv")

wh = as.numeric(data$D_Q10)
ll = as.numeric(data$I_Q04d)
cntry = as.numeric(data$CNTRYID)

df <- data.frame(wh=wh, ll=ll, cntry=cntry)
df <- na.omit(df) # NAを除去
df$cntry <- as.numeric(factor(df$cntry)) # 国のIDを連番に振り直す
set.seed(1234) # シード固定
sampleIndex <- sample(nrow(df), 5000) # 全データを用いると計算が遅すぎたので5000データのみ用いる
df2 <- df[sampleIndex,]
data <- list(N=nrow(df2), K=max(df2$cntry), ll=df2$ll, wh=df2$wh, cntry=df2$cntry)
fit <- stan(file='model.stan', data=data, seed=1234)
fit

Stanのコードを以下に示します。

data {
    int N;
    int K;
    real ll[N];
    real wh[N];
    int<lower=1, upper=K> cntry[N];
}

parameters {
    real a0;
    real b0;
    real a[K];
    real b[K];
    real<lower=0> s_a;
    real<lower=0> s_b;
    real<lower=0> s_ll;
}


model {
    for (k in 1:K) {
        a[k] ~ normal(a0, s_a);
        b[k] ~ normal(b0, s_b);
    }
    for (n in 1:N) {
        ll[n] ~ normal(a[cntry[n]] + b[cntry[n]]*wh[n], s_ll);
    }
}

この分析結果は以下の通りです。

          mean se_mean    sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
a0        3.72    0.00  0.08     3.55     3.66     3.72     3.78     3.88  1789 1.00
b0        0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.00   554 1.00
a[1]      3.72    0.00  0.12     3.48     3.65     3.73     3.80     3.94  4000 1.00
a[2]      3.83    0.00  0.08     3.69     3.78     3.82     3.88     3.98   393 1.00
a[3]      3.30    0.00  0.12     3.07     3.22     3.30     3.38     3.53  4000 1.00
a[4]      4.17    0.00  0.12     3.94     4.09     4.17     4.25     4.39  4000 1.00
a[5]      3.66    0.00  0.11     3.44     3.59     3.66     3.74     3.88  4000 1.00
a[6]      3.79    0.00  0.11     3.58     3.72     3.79     3.85     4.00  4000 1.00
a[7]      4.19    0.00  0.10     3.99     4.13     4.19     4.26     4.37  4000 1.00
a[8]      4.15    0.00  0.12     3.92     4.07     4.14     4.22     4.39  4000 1.00
a[9]      3.64    0.00  0.10     3.45     3.58     3.64     3.71     3.87  4000 1.00
a[10]     4.22    0.00  0.10     4.03     4.16     4.22     4.29     4.44  4000 1.00
a[11]     3.99    0.00  0.10     3.78     3.92     3.99     4.05     4.19  4000 1.00
a[12]     3.88    0.00  0.09     3.68     3.82     3.88     3.94     4.06  4000 1.00
a[13]     3.44    0.00  0.12     3.21     3.36     3.43     3.51     3.69  4000 1.00
a[14]     4.06    0.00  0.11     3.85     3.99     4.06     4.13     4.27  4000 1.00
a[15]     3.27    0.00  0.11     3.07     3.20     3.27     3.34     3.49  4000 1.00
a[16]     4.05    0.00  0.12     3.82     3.98     4.05     4.13     4.28  4000 1.00
a[17]     3.19    0.00  0.11     2.97     3.11     3.19     3.27     3.41  4000 1.00
a[18]     3.01    0.00  0.11     2.79     2.95     3.02     3.09     3.22  4000 1.00
a[19]     3.21    0.00  0.13     2.98     3.12     3.21     3.29     3.47  4000 1.00
a[20]     3.70    0.00  0.09     3.52     3.64     3.70     3.77     3.88  4000 1.00
a[21]     4.09    0.00  0.11     3.86     4.02     4.09     4.16     4.29  4000 1.00
a[22]     3.43    0.00  0.11     3.21     3.37     3.44     3.50     3.62  1250 1.00
a[23]     3.88    0.00  0.10     3.70     3.81     3.88     3.95     4.10  4000 1.00
a[24]     3.65    0.00  0.12     3.41     3.57     3.65     3.74     3.88  4000 1.00
a[25]     3.14    0.00  0.12     2.89     3.06     3.14     3.21     3.37  4000 1.00
a[26]     3.58    0.00  0.12     3.33     3.51     3.58     3.66     3.81  2300 1.00
a[27]     3.22    0.00  0.12     2.97     3.14     3.22     3.30     3.45  4000 1.00
a[28]     4.20    0.00  0.11     3.97     4.12     4.20     4.27     4.42  4000 1.00
a[29]     3.62    0.00  0.13     3.35     3.53     3.61     3.70     3.88  4000 1.00
a[30]     4.21    0.00  0.11     3.99     4.14     4.21     4.29     4.43  4000 1.00
b[1]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1598 1.00
b[2]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.00   423 1.00
b[3]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[4]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  2215 1.00
b[5]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[6]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[7]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  2287 1.00
b[8]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1636 1.00
b[9]      0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[10]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1054 1.00
b[11]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[12]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[13]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1819 1.00
b[14]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[15]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.00  1229 1.00
b[16]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[17]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  2314 1.00
b[18]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1594 1.00
b[19]     0.00    0.00  0.00    -0.01     0.00     0.00     0.00     0.00   626 1.00
b[20]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1831 1.00
b[21]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  1809 1.00
b[22]     0.00    0.00  0.00     0.00     0.00     0.00     0.01     0.01   801 1.00
b[23]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01   923 1.00
b[24]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[25]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[26]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  2305 1.00
b[27]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[28]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[29]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
b[30]     0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.01  4000 1.00
s_a       0.41    0.00  0.06     0.30     0.36     0.40     0.44     0.54  4000 1.00
s_b       0.00    0.00  0.00     0.00     0.00     0.00     0.00     0.00   221 1.01
s_ll      0.87    0.00  0.01     0.86     0.87     0.87     0.88     0.89  4000 1.00
lp__  -1647.09    1.28 14.78 -1671.89 -1657.53 -1648.84 -1637.76 -1614.22   134 1.03

まず一番右のRhatに注目すると、すべて1.1より小さい値となっています。今回はデフォルトの設定を用いたためchain数が4なので「chain数が3以上ですべてのパラメータでRhat < 1.1となること」(p.42)という収束の条件を満たしています。
そのためこの結果を信用してもよさそうだと判断することができます。(この記事ではStanの説明はしないので本を読んでください)

次にb0に注目すると平均も分散も0.00となっており、全体の傾向としては労働時間が知的好奇心に与える影響はないと言えそうです。各国でのbの値を見てもほとんど0であり、この分析結果から労働時間が知的好奇心に影響を与えると結論づけるのは無理がありそうです。

a0に注目すると分散がbに比べて大きな値を取っているので、個人間の差異はおもに切片の項で説明されると考えるのがよさそうです。

以下にaの国別プロットを示します。
f:id:kivantium:20171116163050p:plain
国ごとに違いがある様子が分かります。

bの国別プロットは以下の通りです。
f:id:kivantium:20171116163131p:plain
全ての国でほとんどゼロであり、国ごとの違いも少ないことが分かります。確かに多くの国で正ではあるようですが、0.01を超えるのは97.5%区間を超えたあたりであることから、この事後分布から正の相関があると言うのも難しいと思います。

今回の5000サンプルを抽出した分析からは、

  • 労働時間が知的好奇心に影響を与えると考えるのは難しい
  • 国の違いは知的好奇心に影響を与えていると考えてよさそう

という結論になります。

国の違いは知的好奇心に影響を与えており、また国の違いは労働時間に影響を与えているので知的好奇心と労働時間の間に相関があるように分析されてしまいましたが、少なくともこの結果からは労働時間が知的好奇心に影響を与えていると結論づけることはできないと思いました。

Stanによる分析は初めてなので何か怪しいことをしていたらご指摘お願いします。

心残り

  • 処理時間の関係から全データを用いた分析ができなかった
  • 奥村先生の記事で指摘されていたサンプルサイズが大きいために相関が小さいのに非常に有意と判定されてしまう問題は解消できなかった
  • PystanでやりたかったがNAの扱いがよく分からなくてできなかった

参考文献

StanとRでベイズ統計モデリング (Wonderful R)

StanとRでベイズ統計モデリング (Wonderful R)

以下はグラフのプロットに用いたコードです。

ms <- rstan::extract(fit)
N_mcmc <- length(ms$lp__)

nation <- c("ベルギー","カナダ","チリ","キプロス","チェコ","デンマーク","エストニア","フィンランド","フランス","ドイツ","ギリシャ","アイルランド","イスラエル","イタリア","日本","韓国","リトアニア","オランダ","ニュージーランド","ノルウェー","ポーランド","ロシア","シンガポール","スロバキア","スロベニア","スペイン","スウェーデン","トルコ","イギリス","アメリカ") # IDと国名の関係。(気をつけましたが、ミスがあるかもしれないです)
param_names <- c('mcmc', paste0('a_', nation), paste0('b_', nation))
d_est <- data.frame(1:N_mcmc, ms$a, ms$b)
colnames(d_est) <- param_names
data.frame.quantile.mcmc <- function(x, y_mcmc, probs=c(2.5, 25, 50, 75, 97.5)/100) {
   qua <- apply(y_mcmc, 2, quantile, probs=probs)
   d <- data.frame(X=x, t(qua))
   colnames(d) <- c('X', paste0('p', probs*100))
   return(d)
}
d_qua <- data.frame.quantile.mcmc(x=param_names[-1], y_mcmc=d_est[,-1])
d_melt <- reshape2::melt(d_est, id=c('mcmc'), variable.name='X')
d_melt$X <- factor(d_melt$X, levels=rev(levels(d_melt$X)))

p <- ggplot()
p <- p + theme_grey()
p <- p + coord_flip()
p <- p + geom_violin(data=d_melt[grep('a_', d_melt$X), ], aes(x=X, y=value))
p <- p + geom_pointrange(data=d_qua[grep('a_', d_qua$X), ], aes(x=X, y=p50, ymin=p2.5, ymax=p97.5), size=1)
p <- p + labs(x='切片', y='国名')
ggsave(file='graph1.png', plot=p)

p <- ggplot()
p <- p + theme_grey()
p <- p + coord_flip()
p <- p + geom_violin(data=d_melt[grep('b_', d_melt$X), ], aes(x=X, y=value))
p <- p + geom_pointrange(data=d_qua[grep('b_', d_qua$X), ], aes(x=X, y=p50, ymin=p2.5, ymax=p97.5), size=1)
p <- p + labs(x='傾き', y='国名')
ggsave(file='graph2.png', plot=p)

広告コーナー