package com.shreyansh.stransfer.renderscript_neuralnet;

import android.content.Context;
import android.support.v8.renderscript.Allocation;
import android.support.v8.renderscript.Element;
import android.support.v8.renderscript.RenderScript;
import android.util.Log;
import java.io.IOException;
import java.nio.FloatBuffer;

/* loaded from: classes.dex */
/**
 * Batch-normalization layer whose per-element math runs in the RenderScript kernel
 * {@link ScriptC_batchnormalization}.
 *
 * <p>The layer keeps four parameter vectors of length {@code size} ({@code gamma}, {@code beta},
 * {@code avg_mean}, {@code avg_var}) on the Java side. Each has a matching {@link Allocation}
 * that is bound to the script once, in the constructor. {@link #loadModel(String)} fills the
 * arrays from asset files and copies them into those allocations.
 */
public class BatchNormalization extends NeuralNetLayerBase {
    /** Mode passed to {@code AssetManager.open}; 3 is the value of {@code AssetManager.ACCESS_BUFFER}. */
    private static final int ASSET_ACCESS_MODE = 3;

    // Presumably the running (inference-time) mean/variance statistics — TODO confirm against the exporter.
    private final float[] avg_mean;
    private final Allocation avg_mean_alloc;
    private final float[] avg_var;
    private final Allocation avg_var_alloc;
    private final float[] beta;
    private final Allocation beta_alloc;
    private final float[] gamma;
    private final Allocation gamma_alloc;
    private final ScriptC_batchnormalization rs_BN;
    private final int size;

    /**
     * Creates the layer and binds its (initially zero-filled) parameter allocations to the kernel.
     *
     * @param context      Android context used later to read model assets
     * @param renderScript RenderScript context that owns the allocations and the script
     * @param size         number of elements in each parameter vector
     */
    public BatchNormalization(Context context, RenderScript renderScript, int size) {
        super(context, renderScript);
        this.size = size;
        this.gamma = new float[size];
        this.beta = new float[size];
        this.avg_mean = new float[size];
        this.avg_var = new float[size];
        this.gamma_alloc = Allocation.createSized(this.mRS, Element.F32(this.mRS), size);
        this.beta_alloc = Allocation.createSized(this.mRS, Element.F32(this.mRS), size);
        this.avg_mean_alloc = Allocation.createSized(this.mRS, Element.F32(this.mRS), size);
        this.avg_var_alloc = Allocation.createSized(this.mRS, Element.F32(this.mRS), size);
        this.rs_BN = new ScriptC_batchnormalization(this.mRS);
        // The kernel keeps references to these allocations, so later copyFrom() calls in
        // loadModel() update what the kernel sees without re-binding.
        this.rs_BN.set_beta_alloc(this.beta_alloc);
        this.rs_BN.set_gamma_alloc(this.gamma_alloc);
        this.rs_BN.set_mean_alloc(this.avg_mean_alloc);
        this.rs_BN.set_var_alloc(this.avg_var_alloc);
        this.rs_BN.set_size(size);
    }

    /**
     * Loads the four parameter vectors from the assets {@code <str>/gamma}, {@code <str>/beta},
     * {@code <str>/avg_mean} and {@code <str>/avg_var}. Each is copied into its RenderScript
     * allocation.
     *
     * @param str asset directory holding this layer's parameter files
     * @throws IOException if an asset cannot be opened or read
     */
    @Override // com.shreyansh.stransfer.renderscript_neuralnet.NeuralNetLayerBase
    public void loadModel(String str) throws IOException {
        loadParameter(str + "/gamma", this.gamma, this.gamma_alloc);
        loadParameter(str + "/beta", this.beta, this.beta_alloc);
        loadParameter(str + "/avg_mean", this.avg_mean, this.avg_mean_alloc);
        loadParameter(str + "/avg_var", this.avg_var, this.avg_var_alloc);
        Log.v(NeuralNetLayerBase.TAG, "BatchNormalization loaded: " + this.gamma[0] + " " + this.beta[0] + " " + this.avg_var[0] + " " + this.avg_mean[0]);
    }

    /**
     * Reads one asset of floats into {@code target} and then copies it into {@code allocation}.
     * The stream is always closed, even if reading fails.
     */
    private void loadParameter(String assetPath, float[] target, Allocation allocation) throws IOException {
        this.mInputStream = this.mContext.getAssets().open(assetPath, ASSET_ACCESS_MODE);
        try {
            FloatBuffer.wrap(target).put(readInput(this.mInputStream).asFloatBuffer());
        } finally {
            this.mInputStream.close();
        }
        allocation.copyFrom(target);
    }

    /**
     * Runs the batch-normalization kernel in place: {@code allocation} is both input and output.
     * Blocks until RenderScript has finished, adds the elapsed milliseconds to
     * {@code normalizeTime}, and logs them.
     *
     * @param allocation activations to normalize; overwritten with the result
     */
    public void process(Allocation allocation) {
        // nanoTime is monotonic, unlike currentTimeMillis, so clock adjustments cannot skew the timing.
        long startNanos = System.nanoTime();
        this.rs_BN.forEach_process(allocation, allocation);
        // Wait for the asynchronous kernel so the measured interval covers its execution.
        this.mRS.finish();
        long elapsedMillis = (System.nanoTime() - startNanos) / 1000000L;
        this.normalizeTime += elapsedMillis;
        Log.v(NeuralNetLayerBase.TAG, "BatchNormalization, size: " + this.size + " process time: " + elapsedMillis);
    }
}
