RustによるNelder-Mead法の実装
Rustの練習その2としてNelder-Nead法を実装しました。
Nelder-Mead法のアルゴリズムには様々な変種がありますが、ここでは以下の本に書いてあるものを使いました。

- 作者: Charles Audet,Warren Hare
- 出版社/メーカー: Springer
- 発売日: 2017/12/13
- メディア: ハードカバー
- この商品を含むブログを見る
また、最適化の対象にはTest functions for optimization - Wikipediaからいくつか選びました。
実装
行列演算ライブラリとしてrust-ndarrayを使うためにCargo.toml
に
[dependencies] ndarray = "0.11.0"
を追記しました。
main.rs
は以下の通りです。
extern crate ndarray; use ndarray::*; use std::f64; use std::f64::consts; fn booth(a: &Array1<f64>) -> f64 { if a.len() != 2 { panic!("input dimension must be 2."); } let x = a[0]; let y = a[1]; (x+2.0*y-7.0).powf(2.0) + (2.0*x+y-5.0).powf(2.0) } fn himmelblau(a: &Array1<f64>) -> f64 { if a.len() != 2 { panic!("input dimension must be 2."); } let x = a[0]; let y = a[1]; (x*x+y-11.0).powf(2.0) + (x+y*y-7.0).powf(2.0) } fn ackley(a: &Array1<f64>) -> f64 { if a.len() != 2 { panic!("input dimension must be 2."); } let x = a[0]; let y = a[1]; let pi = consts::PI; -20.0 * (-0.2 * (0.5 * (x*x + y*y)).sqrt()).exp() - (0.5*(2.0*pi*x).cos()+(2.0*pi*y).cos()).exp() + 1.0_f64.exp() + 20.0 } fn nelder_mead<T>(f: T, x_init: &Array1<f64>) -> (f64, Array1<f64>) where T: Fn(&Array1<f64>) -> f64 { // dimension let n = x_init.len(); // constant let delta_e = 2.0; let delta_oc = 0.5; let delta_ic = -0.5; let gamma = 0.5; // initial simplex // https://github.com/scipy/scipy/blob/master/scipy/optimize/optimize.py let nonzdelt = 0.05; let zdelt = 0.00025; let mut y: Vec<(f64, Array1<f64>)> = Vec::new(); y.push((f(&x_init), x_init.clone())); for k in 0..n { let mut y_k = x_init.clone(); if y_k[k] != 0.0 { y_k[k] = (1.0 + nonzdelt) * y_k[k]; } else { y_k[k] = zdelt; } y.push((f(&y_k), y_k)); } for _ in 0..1000 { y.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); let f_best = y[0].0; // Reflect let mut _x_c = y[0].1.clone(); for i in 1..n { _x_c += &y[i].1; } let x_c = &_x_c / (n as f64); let x_r = &x_c + &(&x_c - &y[n].1); let f_r = f(&x_r); if f_best <= f_r && f_r < y[n-1].0 { y[n] = (f_r, x_r); continue; } // Expand if f_r < f_best { let x_e = &x_c + &(delta_e * (&x_c - &y[n].1)); let f_e = f(&x_e); if f_e < f_r { y[n] = (f_e, x_e); continue; } else { y[n] = (f_r, x_r); continue; } } // Outside Contraction if y[n-1].0 <= f_r && f_r < y[n].0 { let x_oc = &x_c + &(delta_oc * (&x_c - &y[n].1)); let f_oc = f(&x_oc); if f_oc < f_r { y[n] = (f_oc, x_oc); continue; } else { y[n] = (f_r, x_r); continue; } } // Inside Contraction if y[n].0 <= f_r { let x_ic = &x_c + &(delta_ic * (&x_c - &y[n].1)); let f_ic = f(&x_ic); if f_ic < y[n].0 { y[n] = (f_ic, x_ic); continue; } } // Shrink for i in 1..n+1 { let y_i = &y[0].1 + &(gamma * (&y[i].1 - &y[0].1)); let f_i = f(&y_i); y[i] = (f_i, y_i); } } y.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); (y[0].0, y[1].1.clone()) } fn main() { let y = array![3.0, 3.0]; let (min, v) = nelder_mead(&booth, &y); println!("a minimum of Booth: {} when (x, y) = ({})", min, v); let (min, v) = nelder_mead(&himmelblau, &y); println!("a minimum of Himmelblau: {} when (x, y) = ({})", min, v); let (min, v) = nelder_mead(&ackley, &y); println!("a minimum of Ackley: {} when (x, y) = ({})", min, v); }
結果は、
a minimum of Booth: 0 when (x, y) = ([1.0000000000000004, 3]) a minimum of Himmelblau: 0 when (x, y) = ([3, 2]) a minimum of Ackley: 7.250114821582645 when (x, y) = ([2.987540988158525, 2.9937622971963957])
となり、Booth関数とHimmelblau関数については大域最小値(の一つ)を発見できました。
Ackley関数ではうまく動作しませんでしたが、これは関数の形状的に仕方がないでしょう。
以前のバージョン
この記事の以前のバージョンでは英語版Wikipediaに書いてあるアルゴリズムを参考にしたコードを載せていました。
当時ハマったところとその解決編を合わせて残しておきます。
extern crate ndarray; use ndarray::*; fn booth(x: &Array1<f64>) -> f64 { (x[0]+2.0*x[1]-7.0).powf(2.0) + (2.0*x[0]+x[1]-5.0).powf(2.0) } fn nelder_mead<T>(f: T) where T: Fn(&Array1<f64>) -> f64 { // constant let alpha = 1.0; let gamma = 2.0; let rho = 0.5; let sigma = 0.5; // initial simplex let x0 = array![0.0, 0.0]; let x1 = array![5.0, 0.0]; let x2 = array![0.0, 5.0]; let mut x = vec![(f(&x0), x0), (f(&x1), x1), (f(&x2), x2)]; for _ in 0..100 { x.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); // ハマりポイントその1 println!("best value, vector: {}, {}", x[0].0, x[0].1); // centroid let xo = (&x[0].1 + &x[1].1) / 2.0; // ハマりポイントその2a // Reflection let xr = &xo - &(alpha * (&xo - &x[2].1)); // ハマりポイントその2b let f_xr = f(&xr); if x[0].0 <= f_xr && f_xr < x[1].0 { x[2] = (f_xr, xr); continue; } // Expansion if f_xr < x[0].0 { let xe = &xo + &(gamma * (&xr - &xo)); let f_xe = f(&xe); if f_xe < f_xr { x[2] = (f_xe, xe); continue; } else { x[2] = (f_xr, xr); continue; } } // Contraction let xc = &xo + &(rho * (&x[2].1 - &xo)); let f_xc = f(&xc); if f_xc < x[2].0 { x[2] = (f_xc, xc); continue; } // Shrink for i in 1..3 { let xi = &x[0].1 + &(sigma * (&x[i].1 - &x[0].1)); let f_xi = f(&xi); x[i] = (f_xi, xi); } } } fn main() { nelder_mead(&booth); }
結果
cargo run
すると
best value, vector: 9, [0, 5] best value, vector: 9, [0, 5] best value, vector: 5.9140625, [2.8125, 1.5625] best value, vector: 1.30712890625, [1.328125, 2.265625] (中略) best value, vector: 0.1797569358295823, [1.3125408096108064, 2.7780022306583767] best value, vector: 0.1797569358295823, [1.3125408096108064, 2.7780022306583767] best value, vector: 0.1797569358295823, [1.3125408096108064, 2.7780022306583767]
となりました。Booth functionはf(1, 3)=0
が最小値なので微妙です。
ハマりポイント(コード参照)
その1
Nelder-Mead法の実装には関数値でソートする必要があるのですが、f64型に対してはsort()
を使うことができませんでした。
How to sort a Vec of floats? - help - The Rust Programming Language Forumを見てこう書きました。
追記: ordered-floatというのがあるらしいです。
その2
最初は2aを
let xo = (&x[0].1 + &x[1].1) / 2.0;
と書き、2bを
let xr = &xo - alpha * (&xo - &x[2].1);
と書いていたのですが、前者はコンパイルを通り、後者はコンパイルが通りませんでした。
エラーメッセージを読むと期待した通りの型になっていないっぽいのですが、どうしてなのかは分かりませんでした。
その後
let xr = xo.clone() - alpha * (xo.clone() - x[2].1.clone());
としたらコンパイルが通るようになりました。
「clone()すると複製コストがかかるような気がする(気のせいかもしれない)ので、cloneしない方法があるなら知りたいと思っています」と書いたところ、有識者からご指摘をいただき今のコードになりました。
ndarray も rust も分からないので雑なリプライになるけど ```let xr = &xo - &(alpha * (&xo - &x[2].1));``` とするとコンパイルは通ると思う。https://t.co/uTfTxHFJ4R によると &Array1 + Array1 という形の演算はサポートされていなさそう。
— 再炎上 (@mofmoffox) 2018年2月11日
もうmofmoffoxが書いてるけど左辺を&にするとコンパイルが通ります。参照同士の演算子オーバーロードはdelegateされておらず、型ごとに複数実装する仕様になっているのでライブラリごとに確認が必要で、プリミティブは全ての組み合わせを実装していますがndarrayは実装していないようです。
— Masaki=就寝=Hara (@qnighy) 2018年2月11日