/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.functions.federated.primitives.part3;

import java.util.Arrays;
import java.util.Collection;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized test that runs a reference DML script on a locally written matrix {@code X} and a second DML
 * script that is additionally handed the addresses of four local federated workers, and then compares the
 * result {@code S} produced by both runs.
 *
 * <p>
 * Each parameter set consists of the number of rows, the number of columns and a flag that is forwarded to the
 * federated script as {@code rP} (row partitioning). Every parameter set is executed once in
 * {@code SINGLE_NODE} and once in {@code SPARK} execution mode.
 */
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedTransferLocalDataTest extends AutomatedTestBase {
	private final static String TEST_DIR = "functions/federated/";
	private final static String TEST_NAME1 = "FederatedTransferLocalDataTest";
	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedTransferLocalDataTest.class.getSimpleName() + "/";

	private final static int blocksize = 1024;
	// number of federated workers; the federated script is handed one address per worker (in_X1 .. in_X4)
	private final static int NUM_WORKERS = 4;

	/** Number of rows of the generated input matrix. */
	@Parameterized.Parameter()
	public int rows;
	/** Number of columns of the generated input matrix. */
	@Parameterized.Parameter(1)
	public int cols;
	/** Forwarded to the federated script as {@code rP=TRUE} or {@code rP=FALSE}. */
	@Parameterized.Parameter(2)
	public boolean rowPartitioned;

	/**
	 * Clears previously collected assertion information and registers the configuration of this test.
	 */
	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
	}

	/**
	 * Supplies the parameter sets {@code {rows, cols, rowPartitioned}}.
	 *
	 * @return the parameter sets this test is executed with
	 */
	@Parameterized.Parameters
	public static Collection<Object[]> data() {
		return Arrays.asList(new Object[][] {{12, 4, true}, {12, 4, false},});
	}

	/** Runs the transfer test in single-node execution mode. */
	@Test
	public void federatedTransferCP() {
		runTransferTest(Types.ExecMode.SINGLE_NODE);
	}

	/** Runs the transfer test in Spark execution mode. */
	@Test
	public void federatedTransferSP() {
		runTransferTest(Types.ExecMode.SPARK);
	}

	/**
	 * Writes the input matrix, starts the federated workers, runs the reference script and the federated script,
	 * and compares their results.
	 *
	 * <p>
	 * All work that follows the execution-mode switch happens inside a single {@code try} block, so that a
	 * failure at any point (configuration loading, input writing, port allocation, worker start-up or script
	 * execution) still shuts down the workers that were started and restores the previous execution mode.
	 *
	 * @param execMode execution mode under which both scripts are run
	 */
	private void runTransferTest(Types.ExecMode execMode) {
		ExecMode platformOld = setExecMode(execMode);

		// only the first 'started' slots hold workers that were actually launched
		Process[] workers = new Process[NUM_WORKERS];
		int started = 0;

		try {
			getAndLoadTestConfiguration(TEST_NAME1);
			String HOME = SCRIPT_DIR + TEST_DIR;

			// write input matrices
			double[][] X = getRandomMatrix(rows, cols, 1, 5, 1, 3);

			MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, blocksize, (long) rows * cols);
			writeInputMatrixWithMTD("X", X, false, mc);

			// empty script name because we don't execute any script, just start the worker
			fullDMLScriptName = "";

			// reserve all ports before launching the first worker
			int[] ports = new int[NUM_WORKERS];
			for(int i = 0; i < NUM_WORKERS; i++)
				ports[i] = getRandomAvailablePort();

			for(int i = 0; i < NUM_WORKERS; i++) {
				// all but the last worker are started with FED_WORKER_WAIT_S; the last one uses the
				// single-argument overload (NOTE(review): presumably a longer default wait so that all
				// workers are up before the scripts run - confirm against AutomatedTestBase)
				workers[i] = (i < NUM_WORKERS - 1) ? startLocalFedWorker(ports[i], FED_WORKER_WAIT_S)
					: startLocalFedWorker(ports[i]);
				// count the worker only after the launch call returned
				started = i + 1;
			}

			if(!isAlive(workers))
				throw new RuntimeException("Failed starting federated worker");

			rtplatform = execMode;
			if(rtplatform == Types.ExecMode.SPARK) {
				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
			}
			TestConfiguration config = availableTestConfigurations.get(TEST_NAME1);
			loadTestConfiguration(config);

			// Run reference dml script with normal matrix
			fullDMLScriptName = HOME + TEST_NAME1 + "Reference.dml";
			programArgs = new String[] {"-stats", "100", "-args", input("X"), expected("S")};

			runTest(null);

			// Run the script that additionally receives the addresses of the federated workers
			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
			programArgs = new String[] {"-stats", "100", "-nvargs", "in_X=" + input("X"),
				"in_X1=" + TestUtils.federatedAddressNoInput("localhost", ports[0]),
				"in_X2=" + TestUtils.federatedAddressNoInput("localhost", ports[1]),
				"in_X3=" + TestUtils.federatedAddressNoInput("localhost", ports[2]),
				"in_X4=" + TestUtils.federatedAddressNoInput("localhost", ports[3]), "rows=" + rows, "cols=" + cols,
				"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};

			runTest(null);

			// compare via files
			compareResults(1e-9, "Stat-DML1", "Stat-DML2");

		}
		finally {
			try {
				// shut down exactly those workers whose launch call returned
				TestUtils.shutdownThreads(Arrays.copyOf(workers, started));
			}
			finally {
				// restore the previous execution mode even if the shutdown itself fails
				resetExecMode(platformOld);
			}
		}
	}
}
