混合多項分布についての私自身の理解はまだいまいちな気もするけれど、とりあえず、参照元の記事にあったPythonのコードをJavaで書き直してみた。
/**
 * Mixture of multinomial distributions fitted with the EM algorithm.
 *
 * <p>This is a port of a Python reference implementation. All components are weighted equally
 * (there are no separate mixing proportions), and {@code dataset[n][w]} is treated as the count
 * of feature {@code w} in document {@code n}.
 *
 * <p>Likelihoods are evaluated in log space. A direct product of {@code q^d} underflows to zero
 * for documents with large counts, and the resulting 0/0 turns every responsibility into NaN.
 */
public class MultinomialMixture {
    /** Observed counts, one row per document; every row has length {@link #dim}. */
    double[][] dataset;
    /** Number of documents (rows of {@link #dataset}). */
    int num;
    /** Number of features per document. */
    int dim;
    /** Number of mixture components. */
    int mNum;
    /** Responsibilities: {@code c[n][k]} is the posterior weight of component k for document n. */
    double[][] c;
    /** Component parameters: {@code q[k][w]} is the probability of feature w under component k. */
    double[][] q;

    /**
     * Creates a model with {@code k} components, initialized from a random generator seeded with 1.
     *
     * @param dataset count matrix, one row per document (stored by reference, not copied)
     * @param k number of components
     * @throws IllegalArgumentException if {@code dataset} is null, empty or ragged, or {@code k <= 0}
     */
    public MultinomialMixture(double[][] dataset, int k) {
        this(dataset, k, 1L);
    }

    /**
     * Creates a model with {@code k} components whose initial parameters are drawn from a
     * {@link java.util.Random} seeded with {@code seed}.
     *
     * @param dataset count matrix, one row per document (stored by reference, not copied)
     * @param k number of components
     * @param seed seed for the random initialization of {@link #q}
     * @throws IllegalArgumentException if {@code dataset} is null, empty or ragged, or {@code k <= 0}
     */
    public MultinomialMixture(double[][] dataset, int k, long seed) {
        if (dataset == null || dataset.length == 0 || dataset[0] == null || dataset[0].length == 0) {
            throw new IllegalArgumentException("dataset must contain at least one non-empty row");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        this.dataset = dataset;
        num = dataset.length;
        dim = dataset[0].length;
        for (int n = 0; n < num; n++) {
            if (dataset[n] == null || dataset[n].length != dim) {
                throw new IllegalArgumentException("row " + n + " does not have length " + dim);
            }
        }
        mNum = k;
        q = new double[mNum][dim];
        c = new double[num][mNum];
        java.util.Random random = new java.util.Random(seed);
        for (int i = 0; i < mNum; i++) {
            for (int j = 0; j < dim; j++) {
                q[i][j] = random.nextDouble();
            }
            q[i] = normalize(q[i]);
        }
    }

    /**
     * Runs {@code loop} EM iterations, each an E-step followed by an M-step.
     *
     * @param loop number of iterations; a non-positive value does nothing
     */
    public void execute(int loop) {
        for (int i = 0; i < loop; i++) {
            stepE();
            stepM();
        }
    }

    /**
     * Returns a new array holding {@code value} scaled so that its entries sum to 1.
     * A zero sum yields NaN entries (0/0); callers that can hit that case must check first.
     */
    double[] normalize(double[] value) {
        double[] ret = new double[value.length];
        double sum = 0;
        for (int i = 0; i < value.length; i++) {
            sum += value[i];
        }
        for (int i = 0; i < value.length; i++) {
            ret[i] = value[i] / sum;
        }
        return ret;
    }

    /**
     * Returns {@code prod_i u[i]^x[i]} (Python: {@code prod([q[w] ** d[w] for w in range(W)])}).
     * The direct product can underflow to 0 for large counts; the EM steps use the log-space form.
     */
    double multi(double[] u, double[] x) {
        double value = 1;
        for (int i = 0; i < u.length; i++) {
            value *= Math.pow(u[i], x[i]);
        }
        return value;
    }

    /** Returns {@code log(multi(u, x))}, computed without forming the (possibly underflowing) product. */
    private static double logMulti(double[] u, double[] x) {
        double logValue = 0;
        for (int i = 0; i < u.length; i++) {
            // Skip zero counts: 0 * log(0) would be NaN, whereas pow(0, 0) == 1 contributes nothing.
            if (x[i] != 0) {
                logValue += x[i] * Math.log(u[i]);
            }
        }
        return logValue;
    }

    /**
     * Returns {@code log(sum_i exp(logs[i]))}, shifting by the maximum to avoid under/overflow.
     * Returns negative infinity when every entry is negative infinity (the sum is 0).
     */
    private static double logSumExp(double[] logs) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logs) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max;
        }
        double total = 0;
        for (double v : logs) {
            total += Math.exp(v - max);
        }
        return max + Math.log(total);
    }

    /**
     * E-step: sets {@code c[n]} to the normalized likelihoods {@code multi(q[k], dataset[n])}
     * over components k (Python: {@code C[n] = normalize([multi(D[n], Q[k]) for k in range(K)])}).
     */
    void stepE() {
        double[] logp = new double[mNum];
        for (int i = 0; i < num; i++) {
            for (int j = 0; j < mNum; j++) {
                logp[j] = logMulti(q[j], dataset[i]);
            }
            double logZ = logSumExp(logp);
            if (logZ == Double.NEGATIVE_INFINITY) {
                // No component can generate this document at all; spread the weight uniformly
                // instead of producing 0/0 = NaN, which would poison every later step.
                java.util.Arrays.fill(c[i], 1.0 / mNum);
                continue;
            }
            for (int j = 0; j < mNum; j++) {
                c[i][j] = Math.exp(logp[j] - logZ);
            }
        }
    }

    /**
     * M-step: sets {@code q[k]} to the responsibility-weighted feature counts, normalized
     * (Python: {@code Q[k] = normalize([sum(C[n][k] * D[n][w] for n) for w in range(W)])}).
     * A component with zero total weight keeps its previous parameters rather than becoming NaN.
     */
    void stepM() {
        for (int i = 0; i < mNum; i++) {
            double[] value = new double[dim];
            for (int n = 0; n < num; n++) {
                double weight = c[n][i];
                for (int w = 0; w < dim; w++) {
                    value[w] += weight * dataset[n][w];
                }
            }
            if (sum(value) > 0) {
                q[i] = normalize(value);
            }
        }
    }

    /**
     * Returns {@code sum_n log(sum_k c[n][k] * multinomial(dataset[n], q[k]))}, as in the Python
     * reference. NOTE(review): this weights components by the responsibilities {@code c}, not by
     * mixing proportions, so it is not the textbook mixture log-likelihood. It is kept to match
     * the reference. Before any E-step {@code c} is all zeros, and the result is negative infinity.
     */
    double logLikelihood() {
        double l = 0;
        double[] terms = new double[mNum];
        for (int i = 0; i < num; i++) {
            for (int j = 0; j < mNum; j++) {
                terms[j] = Math.log(c[i][j]) + logMultinomial(dataset[i], q[j]);
            }
            l += logSumExp(terms);
        }
        return l;
    }

    /**
     * Returns the multinomial probability of counts {@code d} under parameters {@code q}:
     * {@code factorial(sum(d)) / prod(factorial(d[w])) * multi(q, d)}.
     * The value is computed in log space, so large counts no longer give Infinity/Infinity = NaN.
     */
    double multinomial(double[] d, double[] q) {
        return Math.exp(logMultinomial(d, q));
    }

    /** Returns {@code log(multinomial(d, q))}. */
    private double logMultinomial(double[] d, double[] q) {
        double logCoefficient = logFactorial(sum(d));
        for (double count : d) {
            logCoefficient -= logFactorial(count);
        }
        return logCoefficient + logMulti(q, d);
    }

    /** Returns {@code log(factorial(x))}, using the same integer loop bound as {@link #factorial}. */
    private static double logFactorial(double x) {
        double logProd = 0;
        for (int i = 2; i < x + 1; i++) {
            logProd += Math.log(i);
        }
        return logProd;
    }

    /**
     * Returns the product of the integers 1, 2, ... while they are below {@code x + 1}
     * (so {@code factorial(0) == 1}). The result overflows to Infinity for large {@code x}.
     */
    double factorial(double x) {
        double prod = 1;
        for (int i = 1; i < x + 1; i++) {
            prod *= i;
        }
        return prod;
    }

    /** Returns the sum of the entries of {@code v}. */
    double sum(double[] v) {
        double sum = 0;
        for (int i = 0; i < v.length; i++) {
            sum += v[i];
        }
        return sum;
    }

    /** Prints the log-likelihood, then the responsibilities {@code C} and parameters {@code Q}. */
    public void print() {
        System.out.println("L=" + logLikelihood());
        System.out.println(formatMatrix("C", c));
        System.out.println(formatMatrix("Q", q));
    }

    /** Formats {@code m} as {@code label=[[a,b],[c,d]]} with no spaces. */
    private static String formatMatrix(String label, double[][] m) {
        StringBuilder buf = new StringBuilder(label).append("=[");
        for (int i = 0; i < m.length; i++) {
            if (i > 0) {
                buf.append(',');
            }
            buf.append('[');
            for (int j = 0; j < m[i].length; j++) {
                if (j > 0) {
                    buf.append(',');
                }
                buf.append(m[i][j]);
            }
            buf.append(']');
        }
        return buf.append(']').toString();
    }
}
使い方は以下のとおり（`dataset` は、各行が1文書、各列が特徴（単語）の出現回数を表す `double[][]`）。
// Fit a 2-component mixture with 10 EM iterations, then print L, C and Q.
MultinomialMixture mixture = new MultinomialMixture(dataset, 2);
mixture.execute(10);
mixture.print();
かなり勢いで書いたので、細かいところは後で直すとして、Python版とほぼ同じ結果が得られたので、とりあえずよしとする。