/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Federated matrix multiplication (aggregate binary) instruction.
 *
 * <p>Accepts the opcodes {@link MapMult#OPCODE}, {@link PMMJ#OPCODE}, {@code cpmm} and {@code rmm}.
 * It dispatches on how the two inputs are federated (see {@link #processInstruction(ExecutionContext)}).
 * The multiplication itself is executed on the federated workers via
 * {@link FederationUtils#callInstruction}.
 */
public class MMFEDInstruction
extends BinaryFEDInstruction {
    /** Opcodes accepted by {@link #parseInstruction(String)}. */
    private static final String[] SUPPORTED_OPCODES = {MapMult.OPCODE, PMMJ.OPCODE, "cpmm", "rmm"};

    /** Minimum number of instruction parts needed: opcode, two inputs and one output. */
    private static final int MIN_PARTS = 4;

    private MMFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.MAPMM, op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates a federated instruction that reuses the operator, operands, opcode and
     * instruction string of an existing Spark aggregate-binary instruction.
     *
     * @param instr the Spark instruction to convert
     * @return the equivalent federated instruction
     */
    public static MMFEDInstruction parseInstruction(AggregateBinarySPInstruction instr) {
        return new MMFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output,
            instr.getOpcode(), instr.getInstructionString());
    }

    /**
     * Parses a serialized matrix-multiplication instruction.
     *
     * <p>Only the opcode and the first three operands (two inputs, one output) are read; any
     * further parts are ignored. The operator is obtained from
     * {@code InstructionUtils.getMatMultOperator(1)}. The argument is presumably the degree of
     * parallelism — confirm.
     *
     * @param str the serialized instruction
     * @return the parsed federated instruction
     * @throws DMLRuntimeException if the instruction has too few parts or an unsupported opcode
     */
    public static MMFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // Guard the fixed-position reads below so a truncated string fails with context
        // instead of an ArrayIndexOutOfBoundsException.
        if (parts.length < MIN_PARTS) {
            throw new DMLRuntimeException("MMFEDInstruction.parseInstruction():: Invalid number of fields ("
                + parts.length + ") in instruction: " + str);
        }
        String opcode = parts[0];
        if (!ArrayUtils.contains(SUPPORTED_OPCODES, opcode)) {
            throw new DMLRuntimeException("MMFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
        return new MMFEDInstruction(aggbin, in1, in2, out, opcode, str);
    }

    /**
     * Executes the federated product of {@code input1} and {@code input2}.
     *
     * <p>Cases are checked in order:
     * <ol>
     *   <li>lhs federated COL, rhs federated ROW, and the mappings aligned ({@code COL_T}).
     *       The product runs directly on the co-located partitions.</li>
     *   <li>lhs federated ROW or PART. The rhs is broadcast to every lhs worker.</li>
     *   <li>rhs federated ROW. The lhs is sliced and broadcast with {@code broadcastSliced}.</li>
     *   <li>lhs federated COL. The rhs is sliced and broadcast with {@code broadcastSliced}.</li>
     * </ol>
     * Depending on the federated-output flags, the result is either left on the workers
     * (partial or regular federated output) or pulled to the coordinator. There it is combined
     * with {@link FederationUtils#aggAdd} or {@link FederationUtils#bind}.
     *
     * @param ec the execution context holding the inputs and receiving the output
     * @throws DMLRuntimeException if the inputs' federation types match none of the cases
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixLineagePair mo1 = ec.getMatrixLineagePair(input1);
        MatrixLineagePair mo2 = ec.getMatrixLineagePair(input2);
        long id = FederationUtils.getNextFedDataID();
        // Placeholder output variable with unknown dimensions, created on the workers before the call.
        FederatedRequest frEmpty = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
            new Object[]{new MatrixCharacteristics(-1L, -1L), Types.DataType.MATRIX});

        if (mo1.isFederated(FTypes.FType.COL) && mo2.isFederated(FTypes.FType.ROW)
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), FTypes.AlignType.COL_T)) {
            // Case 1: both inputs already live on the same workers; no broadcast needed.
            FederatedRequest fr1 = callMatMult(id, mo1.getFedMapping().getID(), mo2.getFedMapping().getID());
            if (_fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), frEmpty, fr1);
                setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), fr1.getID(), ec);
            } else {
                aggregateLocally(mo1.getFedMapping(), true, ec, frEmpty, fr1);
            }
        } else if (mo1.isFederated(FTypes.FType.ROW) || mo1.isFederated(FTypes.FType.PART)) {
            // Case 2: broadcast rhs to all lhs workers.
            FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
            FederatedRequest fr2 = callMatMult(id, mo1.getFedMapping().getID(), fr1.getID());
            boolean isVector = mo2.getNumColumns() == 1L;
            // Worker results are partial when lhs is PART, or (for non-vector rhs) when rhs is PART.
            boolean isPartOut = mo1.isFederated(FTypes.FType.PART)
                || (!isVector && mo2.isFederated(FTypes.FType.PART));
            boolean keepFederated = _fedOut.isForcedFederated()
                || (!isVector && !_fedOut.isForcedLocal());
            if (isPartOut && _fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), true, frEmpty, fr1, fr2);
                setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), fr2.getID(), ec);
            } else if (keepFederated && !isPartOut) {
                mo1.getFedMapping().execute(getTID(), true, frEmpty, fr1, fr2);
                setOutputFedMapping(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), fr2.getID(), ec);
            } else {
                // PART lhs partials are summed; ROW lhs results are bound together.
                aggregateLocally(mo1.getFedMapping(), mo1.isFederated(FTypes.FType.PART), ec, frEmpty, fr1, fr2);
            }
        } else if (mo2.isFederated(FTypes.FType.ROW)) {
            // Case 3: slice lhs to match the rhs row partitions.
            FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
            FederatedRequest fr2 = callMatMult(id, fr1[0].getID(), mo2.getFedMapping().getID());
            if (_fedOut.isForcedFederated()) {
                mo2.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[]{frEmpty, fr2});
                setPartialOutput(mo2.getFedMapping(), mo1.getMO(), mo2.getMO(), fr2.getID(), ec);
            } else {
                aggregateLocally(mo2.getFedMapping(), true, ec, fr1, new FederatedRequest[]{frEmpty, fr2});
            }
        } else if (mo1.isFederated(FTypes.FType.COL)) {
            // Case 4: slice rhs to match the lhs column partitions.
            FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, true);
            FederatedRequest fr2 = callMatMult(id, mo1.getFedMapping().getID(), fr1[0].getID());
            if (_fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[]{frEmpty, fr2});
                setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), fr2.getID(), ec);
            } else {
                aggregateLocally(mo1.getFedMapping(), true, ec, fr1, new FederatedRequest[]{frEmpty, fr2});
            }
        } else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: "
                + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /**
     * Builds the worker-side call of this instruction on inputs identified by the given IDs.
     *
     * @param outputID ID of the worker variable that receives the result
     * @param lhsID worker-side ID bound to {@code input1}
     * @param rhsID worker-side ID bound to {@code input2}
     * @return the request executing this instruction on the workers
     */
    private FederatedRequest callMatMult(long outputID, long lhsID, long rhsID) {
        return FederationUtils.callInstruction(instString, output, outputID,
            new CPOperand[]{input1, input2}, new long[]{lhsID, rhsID}, Types.ExecType.SPARK, false);
    }

    /**
     * Sets the output as a federated object whose worker results still need to be combined.
     * Its dimensions are rows(mo1) x cols(mo2), and its mapping is copied from
     * {@code federationMap} via {@code copyWithNewIDAndRange}.
     */
    private void setPartialOutput(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), mo1.getBlocksize());
        FederationMap outputFedMap = federationMap.copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID);
        out.setFedMapping(outputFedMap);
    }

    /**
     * Sets the output as a federated object of dimensions rows(mo1) x cols(mo2).
     * The mapping is copied from {@code federationMap} via {@code copyWithNewID} with the rhs column count.
     */
    private void setOutputFedMapping(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), mo1.getBlocksize());
        out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
    }

    /** Runs {@code fr} on all workers and combines the results locally (no sliced requests). */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec, FederatedRequest... fr) {
        aggregateLocally(fedMap, aggAdd, ec, (FederatedRequest[]) null, fr);
    }

    /**
     * Runs the request chain on all workers, fetches each worker's result, cleans up the
     * worker-side result variable, and stores the combined matrix as the local output.
     *
     * @param fedMap federation map whose workers execute the requests
     * @param aggAdd {@code true} to combine with {@link FederationUtils#aggAdd},
     *               {@code false} to combine with {@link FederationUtils#bind}
     * @param ec execution context receiving the output
     * @param frSliced optional per-worker sliced requests, or {@code null}
     * @param fr request chain; its last request's ID is the variable fetched from the workers
     */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec, FederatedRequest[] frSliced, FederatedRequest ... fr) {
        long callInstID = fr[fr.length - 1].getID();
        FederatedRequest frG = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, callInstID);
        FederatedRequest frC = fedMap.cleanup(getTID(), callInstID);
        // Append fetch and cleanup so they run after the computation in the same batch.
        FederatedRequest[] requests = ArrayUtils.addAll(fr, frG, frC);
        Future<FederatedResponse>[] ffr = (frSliced != null)
            ? fedMap.execute(getTID(), frSliced, requests)
            : fedMap.execute(getTID(), requests);
        MatrixBlock ret = aggAdd ? FederationUtils.aggAdd(ffr) : FederationUtils.bind(ffr, false);
        ec.setMatrixOutput(output.getName(), ret);
    }
}

