Programmieren in Rust

Memoisierung

Inhaltsverzeichnis

  1. Memoisierung
  2. Der Fixpunkt-Kombinator
  3. Memoisierender Fixpunkt-Kombiantor

Memoisierung

Die rekursive Definition der Fibonacci-Zahlen kann man ohne Umwege als naiven Algorithmus zu ihrer Berechnung nutzen.

fn fib(n: u32) -> u32 {
    if n < 2 {n} else {fib(n - 1) + fib(n - 2)}
}

fn main() {
    for n in 0..48 {
        println!("{:3} | {:10}", n, fib(n));
    }
}

Dieses Programm braucht ziemlich lange bis es durchgelaufen ist. Würden wir u32 gegen u64 ersetzen, käme es nicht bei 48 zum Überlauf, sondern erst bei 94. Die Berechnung von fib(93) würde allerdings ewig dauern. Genauer gesagt wächst der Berechnungsaufwand mit steigendem Argument exponentiell, wie sich bei der komplexitätstheoretischen Analyse herausstellt.

Eine systematische Technik zur Reduktion des Aufwandes ist die Memoisierung. Hierbei werden bereits berechnete Funktionswerte gespeichert, so dass sie danach direkt anstelle der Berechnung abgerufen werden können, wenn die gleiche Zahl abermals als Argument des Funktionsaufrufs vorkommt.

Die folgende Umsetzung nutzt ein Feld zur Speicherung der Funktionswerte.

fn memoized_fib() -> impl FnMut(u32) -> u32 {
    fn fib(n: u32, memo: &mut [u32]) -> u32 {
        let index = n as usize;
        if n != 0 && memo[index] == 0 {
            memo[index] = fib(n - 1, memo) + fib(n - 2, memo);
        }
        memo[index]
    }
    let mut memo = [0; 48];
    memo[1] = 1;
    move |n| fib(n, &mut memo)
}

fn main() {
    let mut fib = memoized_fib();
    for n in 0..48 {
        println!("{:3} | {:10}", n, fib(n));
    }
}

Alternativ kann man, im konkreten Fall mit Kanonen auf Spatzen geschossen, ein assoziatives Feld zur Speicherung einsetzen.

fn memoized_fib() -> impl FnMut(u32) -> u32 {
    use std::collections::HashMap;
    fn fib(n: u32, memo: &mut HashMap<u32, u32>) -> u32 {
        if let Some(&value) = memo.get(&n) {
            value
        } else { 
            let value = fib(n - 1, memo) + fib(n - 2, memo);
            memo.insert(n, value);
            value
        }
    }
    let mut memo = HashMap::from_iter([(0, 0), (1, 1)]);
    move |n| fib(n, &mut memo)
}

Der Fixpunkt-Kombinator

Worauf ich abziele, will ich zur Vereinfachung zunächst in Python darstellen. Ziel ist die Implementierung eines rekursiven Algorithmus. Beispielsweise berechnet die Funktion

def fac(n):
    return 1 if n == 0 else n*fac(n - 1)

bekanntlich die Fakultät einer Zahl. Nun kann man Rekursion alternativ mit einem Hilfsmittel herstellen, dem sogenannten Fixpunkt-Kombinator. Das geht so:

def fix(cb):
    def f(x): return cb(f, x)
    return f

fac = fix(lambda f, n: 1 if n == 0 else n*f(n - 1))

Wir machen es uns nun zur Aufgabe, fix in Rust zu formulieren. Dabei stellt sich das Problem in den Weg, dass man aufgrund der eingefangenen Variable cb (callback) ein Closure benötigt, Rust jedoch keine benannten Closures kennt. Benannte Closures sind allerdings entbehrlich. Mit der folgenden Trickserei kommen wir da raus:

def fix(cb):
    return lambda x: cb(cb, x)

fac = fix(lambda f, n: 1 if n == 0 else n*f(f, n - 1))

Schließlich müssen wir noch das Typsystem befriedigen. Bei der näheren Betrachtung stellt sich heraus, dass f einen rekursiven Typ haben muss, denn f bekommt sich selbst als Argument zugeführt. Das stellt aber kein wesentliches Problem dar, denn rekursive Typen sind mittels struct konstruierbar. Insgesamt findet man die folgende Konstruktion.

struct Rec<'a>(&'a dyn Fn(&Rec, u32) -> u32);

impl Rec<'_> {
    fn call(&self, x: u32) -> u32 {(self.0)(&self, x)}
}

fn fix(cb: impl Fn(&Rec, u32) -> u32) -> impl Fn(u32) -> u32 {
    move |x| cb(&Rec(&cb), x)
}

fn main() {
    let fac = fix(|f, n| if n == 0 {1} else {n*f.call(n - 1)});
    println!("{}", fac(4));
}

Einer generischen Formulierung steht keine Beschwerlichkeit im Wege. Wir gelangen zu:

struct Rec<'a, X, Y>(&'a dyn Fn(&Rec<X, Y>, X) -> Y);

impl<X, Y> Rec<'_, X, Y> {
    fn call(&self, x: X) -> Y {(self.0)(&self, x)}
}

fn fix<X, Y>(cb: impl Fn(&Rec<X, Y>, X) -> Y) -> impl Fn(X) -> Y {
    move |x| cb(&Rec(&cb), x)
}

Damit haben wir nun einen allgemeinen Fixpunkt-Kombinator. Mehr noch, es sind sogar mehrstellige Funktionen realisierbar. Darf man nämlich für X alle möglichen Typen einsetzen, dann darf es auch ein Tupeltyp sein. Beispielsweise kann man die Potenzfunktion folgendermaßen schreiben:

fn main() {
    let pow = fix(|f, (x, n)|
        if n == 0 {1} else {x*f.call((x, n - 1))});
    println!("{}", pow((2, 10)));
}

Wie das Muster

let even = fix(|e, n| {
    let odd = fix(|_, n| if n == 0 {false} else {e.call(n - 1)});
    if n == 0 {true} else {odd(n - 1)}
});

zeigt, kann man mit fix fernerhin wechselseitige Rekursion basteln.

Memoisierender Fixpunkt-Kombinator

Als Nächstes erfolgt der Einbau der Memoisierung in den Fixpunkt-Kombinator. Zuvor gebe ich wieder eine kurze Verdeutlichung in Python, worin das Ziel besteht.

def fix(cb):
    memo = {}
    def f(x):
        if x not in memo: memo[x] = cb(f, x)
        return memo[x]
    return f

fib = fix(lambda f, n: n if n < 2 else f(n - 1) + f(n - 2))

Die Memoisierung muss in der Operation call erfolgen, weil gerade dieser Aufruf bei der Rekursion wieder und wieder durchlaufen wird. Als Speicher diene ebenfalls ein assoziatives Feld, denn auf diese Weise verbleibt die Funktionalität recht allgemein.

use std::{hash::Hash, collections::HashMap};

struct Rec<'a, X, Y> {
    cb: &'a dyn Fn(&mut Rec<X, Y>, X) -> Y,
    memo: &'a mut HashMap<X, Y>
}

impl<X, Y> Rec<'_, X, Y> where X: Eq + Hash + Clone, Y: Clone {
    fn call(&mut self, x: X) -> Y {
        if let Some(value) = self.memo.get(&x) {
            value.clone()
        } else {
            let value = (self.cb)(self, x.clone());
            self.memo.insert(x, value.clone());
            value
        }
    }
}

fn fix<X, Y>(cb: impl Fn(&mut Rec<X, Y>, X) -> Y)
-> impl FnMut(X) -> Y
where X: Eq + Hash + Clone, Y: Clone
{
    let mut memo = HashMap::new();
    move |x| cb(&mut Rec {cb: &cb, memo: &mut memo}, x)
}

fn main() {
    let mut fib = fix(|f, n: u32|
        if n < 2 {n} else {f.call(n - 1) + f.call(n - 2)});

    for n in 0..48 {
        println!("{:3} | {:10}", n, fib(n));
    }
}