/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.attributes.parameters.FeatureContribution;
import hex.genmodel.utils.ArrayUtils;

/**
 * Base class for computing per-feature prediction contributions by delegating
 * to a {@link TreeSHAPPredictor}.
 *
 * <p>Subclasses supply the conversion from a raw {@code double[]} input to the
 * row type {@code E} the underlying predictor works on, and may post-process
 * the raw contributions by overriding {@link #getContribs(float[])}.
 *
 * @param <E> row representation consumed by the underlying {@link TreeSHAPPredictor}
 */
public abstract class ContributionsPredictor<E>
implements PredictContributions {
    /** Length of the contribution array allocated for every prediction. */
    private final int _ncontribs;
    /** Names handed to the constructor combined with {@code "BiasTerm"} via {@code ArrayUtils.append}. */
    private final String[] _contribution_names;
    /** Predictor that does the actual contribution computation. */
    private final TreeSHAPPredictor<E> _treeSHAPPredictor;
    /** Workspace size reported by {@link #_treeSHAPPredictor} at construction time. */
    private final int _workspaceSize;
    /**
     * Per-thread scratch workspace. The field is static, so it is shared by every
     * instance of this class; {@link #getWorkspace()} therefore re-checks the size
     * before reusing whatever the current thread has cached.
     */
    private static final ThreadLocal<TreeSHAPPredictor.Workspace> _workspace = new ThreadLocal<TreeSHAPPredictor.Workspace>();

    /**
     * Creates a predictor backed by the given {@link TreeSHAPPredictor}.
     *
     * @param ncontribs length of the contribution array produced per prediction
     * @param featureContributionNames names of the feature contributions; {@code "BiasTerm"}
     *        is added to them through {@code ArrayUtils.append}
     * @param treeSHAPPredictor predictor used to compute the contributions and to create
     *        workspaces
     */
    public ContributionsPredictor(int ncontribs, String[] featureContributionNames, TreeSHAPPredictor<E> treeSHAPPredictor) {
        _ncontribs = ncontribs;
        _contribution_names = ArrayUtils.append(featureContributionNames, "BiasTerm");
        _treeSHAPPredictor = treeSHAPPredictor;
        _workspaceSize = _treeSHAPPredictor.getWorkspaceSize();
    }

    /**
     * Returns the contribution names built in the constructor.
     *
     * @return the internal name array itself (not a copy)
     */
    @Override
    public final String[] getContributionNames() {
        return _contribution_names;
    }

    /**
     * Computes the contributions for a single input row.
     *
     * @param input raw input values, converted with {@link #toInputRow(double[])}
     * @return the result of {@link #getContribs(float[])} applied to the freshly
     *         computed contribution array
     */
    @Override
    public final float[] calculateContributions(double[] input) {
        float[] contribs = new float[_ncontribs];
        // NOTE(review): the literal arguments 0 and -1 are defined by
        // TreeSHAPPredictor.calculateContributions, which is not visible here.
        _treeSHAPPredictor.calculateContributions(toInputRow(input), contribs, 0, -1, getWorkspace());
        return getContribs(contribs);
    }

    /**
     * Converts a raw input into the row type expected by the underlying predictor.
     *
     * @param var1 raw input values
     * @return the row passed to the {@link TreeSHAPPredictor}
     */
    protected abstract E toInputRow(double[] var1);

    /**
     * Post-processing hook applied to the computed contributions before they are
     * returned. The default implementation returns the argument unchanged.
     *
     * @param contribs contributions written by the underlying predictor
     * @return the contributions to hand back to the caller
     */
    public float[] getContribs(float[] contribs) {
        return contribs;
    }

    /**
     * Returns the calling thread's cached workspace, replacing it with a new one
     * when none is cached yet or when the cached one has a different size than
     * this predictor requires.
     */
    private TreeSHAPPredictor.Workspace getWorkspace() {
        TreeSHAPPredictor.Workspace workspace = _workspace.get();
        if (workspace == null || workspace.getSize() != _workspaceSize) {
            workspace = _treeSHAPPredictor.makeWorkspace();
            assert (workspace.getSize() == _workspaceSize);
            _workspace.set(workspace);
        }
        return workspace;
    }

    /**
     * Computes the contributions for a single input row and returns the subset and
     * order chosen by {@link ContributionComposer}, paired with their names.
     *
     * @param input raw input values
     * @param topN forwarded to {@code ContributionComposer.composeContributions}
     * @param bottomN forwarded to {@code ContributionComposer.composeContributions}
     * @param compareAbs forwarded to {@code ContributionComposer.composeContributions}
     * @return one {@link FeatureContribution} per index returned by the composer,
     *         in the order the composer returned them
     */
    @Override
    public FeatureContribution[] calculateContributions(double[] input, int topN, int bottomN, boolean compareAbs) {
        float[] contributions = calculateContributions(input);
        // NOTE(review): presumably ArrayUtils.range treats the upper bound as inclusive,
        // so the ids cover every contribution name; confirm against ArrayUtils.
        int[] contributionNameIds = ArrayUtils.range(0, _contribution_names.length - 1);
        int[] sorted = new ContributionComposer().composeContributions(contributionNameIds, contributions, topN, bottomN, compareAbs);
        FeatureContribution[] out = new FeatureContribution[sorted.length];
        for (int i = 0; i < sorted.length; ++i) {
            // Each composer result indexes both the name array and the contribution array.
            out[i] = new FeatureContribution(_contribution_names[sorted[i]], contributions[sorted[i]]);
        }
        return out;
    }
}

