package base.evaluate;

import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
 * Accumulates (true label, predicted label) observations and prints a confusion
 * matrix followed by per-label precision, recall and F1 scores (as percentages).
 *
 * <p>Not thread-safe: counts are stored in an unsynchronized {@link HashMap}.
 */
public class ConfusionMatrix {
    /**
     * Observation counts keyed by true label, then by predicted label. A nested map is used
     * instead of a single "true\tpredict" string key so that labels containing a tab
     * character cannot collide with one another.
     */
    private final Map<String, Map<String, Integer>> labelCntMap = new HashMap<>();

    private void sysPrint(String value) {
        System.out.print(value);
    }

    /** Returns how many times {@code trueLabel} was recorded as {@code predictLabel}. */
    private int count(String trueLabel, String predictLabel) {
        Map<String, Integer> row = labelCntMap.get(trueLabel);
        return row == null ? 0 : row.getOrDefault(predictLabel, 0);
    }

    /**
     * Records one observation with integer labels; equivalent to
     * {@code add(String.valueOf(trueLabel), String.valueOf(predictLabel))}.
     */
    public void add(int trueLabel, int predictLabel) {
        add(String.valueOf(trueLabel), String.valueOf(predictLabel));
    }

    /**
     * Records one observation.
     *
     * @param trueLabel    the reference (gold) label
     * @param predictLabel the label produced by the classifier
     */
    public void add(String trueLabel, String predictLabel) {
        labelCntMap.computeIfAbsent(trueLabel, k -> new HashMap<>())
                .merge(predictLabel, 1, Integer::sum);
    }

    /**
     * Prints the confusion matrix restricted to {@code labels} (rows = true label,
     * columns = predicted label), then one line per label with precision, recall and F1
     * as percentages rounded to two decimals.
     *
     * <p>Only pairs whose true and predicted labels both appear in {@code labels} contribute
     * to the metrics. A metric whose denominator is zero (label never predicted, never
     * present, or precision and recall both zero) is reported as {@code 0.0} rather than NaN.
     *
     * @param labels the labels to report, in display order
     * @throws NullPointerException if {@code labels} is null
     */
    public void printResult(List<String> labels) {
        // Fail before anything is printed rather than after a partial header.
        Objects.requireNonNull(labels, "labels");
        printMatrix(labels);
        printMetrics(labels);
    }

    /** Prints the header row of predicted labels and one count row per true label. */
    private void printMatrix(List<String> labels) {
        sysPrint("\t");
        for (String predictLabel : labels) {
            sysPrint("\t" + predictLabel);
        }
        sysPrint("\r\n");
        for (String trueLabel : labels) {
            sysPrint(trueLabel);
            for (String predictLabel : labels) {
                sysPrint("\t" + count(trueLabel, predictLabel));
            }
            sysPrint("\r\n");
        }
    }

    /** Prints "label\tprecision%\trecall%\tf1%" for every label. */
    private void printMetrics(List<String> labels) {
        for (String label : labels) {
            int correct = count(label, label);
            int predicted = 0; // column sum: everything predicted as `label`
            int actual = 0;    // row sum: everything whose true label is `label`
            for (String other : labels) {
                predicted += count(other, label);
                actual += count(label, other);
            }
            double precision = percent(correct, predicted);
            double recall = percent(correct, actual);
            // Derive F1 from the unrounded values so display rounding is applied only once;
            // precision + recall == 0 would otherwise yield 0/0 = NaN.
            double f1 = precision + recall == 0
                    ? 0.0
                    : 2 * precision * recall / (precision + recall);
            sysPrint(label + "\t" + round2(precision) + "%\t" + round2(recall) + "%\t"
                    + round2(f1) + "%\r\n");
        }
    }

    /** Returns {@code numerator / denominator} as a percentage, or 0.0 when the denominator is 0. */
    private static double percent(int numerator, int denominator) {
        // 100.0 (not 100) keeps the multiplication in double arithmetic, avoiding int overflow.
        return denominator == 0 ? 0.0 : numerator * 100.0 / denominator;
    }

    /** Rounds {@code value} to two decimal places using the same rounding as "%.2f". */
    private static double round2(double value) {
        // Locale.ROOT forces '.' as the decimal separator; under e.g. de_DE "%.2f" yields
        // "66,67" and parsing it back would throw NumberFormatException.
        return Double.parseDouble(String.format(Locale.ROOT, "%.2f", value));
    }
}