/* CreateSkeleton
 * Creates a skeleton PlayStation dev project based on the output of SymDumpTE
 * or another source of the appropriate JSON data
 * Copyright 2019 Ben Lincoln
 * https://www.beneaththewaves.net/
 * 
 * This file is part of CreateSkeleton.
 * 
 * CreateSkeleton is free software: you can redistribute it and/or modify
 * it under the terms of version 3 of the GNU General Public License as published by
 * the Free Software Foundation.
 * 
 * CreateSkeleton is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with CreateSkeleton (in the file LICENSE.txt).  
 * If not, see <http://www.gnu.org/licenses/>.
 */
// %PROJECT_NAME%: Machine-generated function-defining script

import ghidra.app.script.GhidraScript;
import ghidra.app.services.DataTypeManagerService;
import ghidra.app.util.importer.MemoryConflictHandler;
import ghidra.app.util.MemoryBlockUtil;
import ghidra.framework.plugintool.PluginTool;
import ghidra.program.flatapi.FlatProgramAPI;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.*;
import ghidra.program.model.block.CodeBlock;
import ghidra.program.model.block.PartitionCodeSubModel;
import ghidra.program.model.data.*;
import ghidra.program.model.listing.*;
import ghidra.program.model.lang.Register;
import ghidra.program.model.lang.RegisterManager;
import ghidra.program.model.lang.RegisterValue;
import ghidra.program.model.mem.Memory;
import ghidra.program.model.mem.*;
import ghidra.program.model.symbol.RefType;
import ghidra.program.model.symbol.*;
import ghidra.util.exception.InvalidInputException;

import java.io.InputStream;
import java.io.StringWriter;
import java.io.PrintWriter;
import java.math.BigInteger;
import java.util.HashMap; 
import java.util.*;

// Ben Lincoln, 2019

public class %PROJECT_NAME%TDRDefineFunctions extends GhidraScript
{
	/**
	 * Describes one parameter (or the return value) of a function to be defined:
	 * its name, its data type, and where its value lives - either in a named
	 * register or at a stack offset.
	 */
	public class FunctionParameter
	{
		/** Name given to the parameter in Ghidra. */
		public String ParameterName;
		/** Ghidra data type; may be null if the type lookup failed (presumably via GetDataTypeFromMap). */
		public DataType ParameterDataType;
		/** Register holding the value; only consulted when IsRegisterParam is true. */
		public String RegisterName;
		/** Stack offset of the value; only consulted when IsRegisterParam is false. */
		public int StackOffset;
		/** True for register storage, false for stack storage. */
		public boolean IsRegisterParam;

		public FunctionParameter(String name, DataType dataType, String registerName, int stackOffset, boolean isRegisterParam)
		{
			this.IsRegisterParam = isRegisterParam;
			this.StackOffset = stackOffset;
			this.RegisterName = registerName;
			this.ParameterDataType = dataType;
			this.ParameterName = name;
		}
	}

	// Counters not referenced by the hand-written code in this template;
	// presumably maintained by the generated import subroutines - verify.
	private int numTried;
	private int numCreated;
	// Not assigned by the hand-written code in this template; presumably
	// initialised/used by generated code - verify before relying on it.
	SymbolTable symbolTable;

	/**
	 * Script entry point. Builds the data type lookup table, defines the functions
	 * that have no function-pointer parameters, then adds function-pointer types to
	 * the table and defines the functions that do have them.
	 * The %...% marker lines are replaced with generated code before compilation.
	 */
	@Override
	public void run() throws Exception
	{
		// Program context for register values (e.g. stack pointer / global pointer).
		// Nothing is set here; presumably the generated calls below pass ctx to
		// SetRegisterAtAddress - verify.
		ProgramContext ctx = currentProgram.getProgramContext();

		// get most data types
		HashMap<String, DataType> dataTypeMap = GetDataTypeMap();

		// Import functions without function pointer parameters (begin)
%FUNCTION_IMPORT_SUBROUTINE_CALLS_1%

		// Import functions without function pointer parameters (end)

		// Import functions with function pointer parameters (begin)
		// Function-pointer types must be in the map before these signatures are built.
		dataTypeMap = importFunctionPointers(dataTypeMap);
%FUNCTIONS_WITH_FUNCTION_POINTER_PARAMETERS_CALLS%

		// Import functions with function pointer parameters (end)

	}

	/**
	 * Builds the name-to-DataType lookup table used when defining functions.
	 * Starts with C primitive names and common aliases, adds the enums, structs
	 * and unions produced by the generated import subroutines, then registers
	 * pointer variants of every entry at one to four levels of indirection, in
	 * both the "T *" and "T (*)" spellings.
	 *
	 * @return the populated map
	 */
	protected HashMap<String, DataType> GetDataTypeMap()
	{
		HashMap<String, DataType> typesByName = new HashMap<>();

		// C primitives
		typesByName.put("bool", new BooleanDataType());
		typesByName.put("double", new DoubleDataType());
		typesByName.put("char", new CharDataType());
		typesByName.put("dword", new DWordDataType());
		typesByName.put("float", new FloatDataType());
		typesByName.put("int", new IntegerDataType());
		typesByName.put("long", new LongDataType());
		typesByName.put("longlong", new LongLongDataType());
		typesByName.put("long long", new LongLongDataType());
		typesByName.put("qword", new QWordDataType());
		typesByName.put("short", new ShortDataType());
		typesByName.put("string", new StringDataType());
		typesByName.put("uchar", new UnsignedCharDataType());
		typesByName.put("uint", new UnsignedIntegerDataType());
		typesByName.put("ulong", new UnsignedLongDataType());
		typesByName.put("ulonglong", new UnsignedLongLongDataType());
		typesByName.put("unsigned char", new UnsignedCharDataType());
		typesByName.put("unsigned int", new UnsignedIntegerDataType());
		typesByName.put("unsigned long", new UnsignedLongDataType());
		typesByName.put("unsigned longlong", new UnsignedLongLongDataType());
		typesByName.put("unsigned long long", new UnsignedLongLongDataType());
		typesByName.put("unsigned short", new UnsignedShortDataType());
		typesByName.put("ushort", new UnsignedShortDataType());
		typesByName.put("void", new VoidDataType());
		typesByName.put("word", new WordDataType());

		// Enums, structs, and unions from debug symbols
		typesByName = importEnums(typesByName);
		typesByName = importStructs(typesByName);
		typesByName = importUnions(typesByName);

		// Snapshot the keys first: the loop below adds entries to the same map.
		List<String> baseNames = new ArrayList<String>(typesByName.keySet());
		for (String baseName : baseNames)
		{
			DataType pointee = typesByName.get(baseName);
			String stars = "";
			// Wrap one more pointer level per pass: T*, T**, T***, T****.
			for (int depth = 1; depth <= 4; depth++)
			{
				pointee = new PointerDataType(pointee);
				stars += "*";
				typesByName.put(baseName + " " + stars, pointee);
				typesByName.put(baseName + " (" + stars + ")", pointee);
			}
		}

		return typesByName;
	}
	
	/**
	 * Looks up a data type by the name used in the generated code, logging a
	 * message when the map has no (non-null) entry for it.
	 *
	 * @return the matching DataType, or null when none is found
	 */
	protected DataType GetDataTypeFromMap(HashMap<String, DataType> dataTypeMap, String typeName)
	{
		DataType found = dataTypeMap.get(typeName);
		if (found != null)
		{
			return found;
		}
		println("Got null result for DataType name '" + typeName + "'.");
		return null;
	}

	/**
	 * Finds or creates the function at functionAddress and applies a signature to it.
	 * The steps are:
	 * - rename the function;
	 * - replace its parameters using updateType (skipped when any element of params is null);
	 * - set the return type, with custom storage when returnTypeVariableStorage is
	 *   non-null and via setReturnType otherwise.
	 *
	 * @return the function, or null when it could not be created
	 * @throws Exception any exception raised by the Ghidra API; it is rethrown after
	 *         the failing step and the stack trace have been printed
	 */
	protected Function DefineFunction(Address functionAddress, String functionName, DataType returnType, VariableStorage returnTypeVariableStorage, Parameter[] params, Function.FunctionUpdateType updateType) throws Exception
	{
		// Describes the step in progress so that a failure can be reported precisely.
		String step = "searching for existing function";
		Function target = null;
		try
		{
			target = getFunctionAt(functionAddress);
			if (target == null)
			{
				step = "creating the basic function";
				this.createFunction(functionAddress, functionName);
				step = "getting the basic function after creation";
				target = getFunctionAt(functionAddress);
				if (target == null)
				{
					println("Couldn't create function " + functionName);
					return null;
				}
			}

			step = "setting the function name";
			target.setName(functionName, SourceType.USER_DEFINED);

			// Check (and report) every parameter before deciding whether to apply them.
			boolean paramsUsable = true;
			for (int idx = 0; idx < params.length; idx++)
			{
				step = "checking parameter " + Integer.toString(idx) + " for being null";
				if (params[idx] == null)
				{
					paramsUsable = false;
					println("Not setting parameters for function '" + functionName + "' because type for parameter " + Integer.toString(idx) + " is null.");
				}
			}
			if (paramsUsable)
			{
				step = "calling func.replaceParameters";
				target.replaceParameters(updateType, true, SourceType.USER_DEFINED, params);
			}

			// Return type goes last so any custom storage set by replaceParameters is already in place.
			if (returnType == null)
			{
				println("Not setting return type for function '" + functionName + "' because the returnType is null.");
			}
			else
			{
				step = "checking to see if the return type's VariableStorage is set";
				if (returnTypeVariableStorage != null)
				{
					step = "calling setReturn (custom storage)";
					target.setReturn(returnType, returnTypeVariableStorage, SourceType.USER_DEFINED);
				}
				else
				{
					step = "calling setReturnType (dynamic storage)";
					target.setReturnType(returnType, SourceType.USER_DEFINED);
				}
			}
		}
		catch (Exception e)
		{
			println("Exception thrown while creating or updating function '" + functionName + "', while " + step + ": " + e.getMessage());
			println(e.toString());
			// Render the stack trace as text so it goes through println like everything else.
			StringWriter traceText = new StringWriter();
			e.printStackTrace(new PrintWriter(traceText));
			println(traceText.toString());
			throw e;
		}
		return target;
	}
	
	/**
	 * Defines (or updates) the function at functionAddress from a generated signature.
	 *
	 * When assignAnyStorage is true, the function is first defined with CUSTOM_STORAGE.
	 * The return value and each parameter get the register or stack location described
	 * by their FunctionParameter. That attempt is skipped when any parameter register
	 * cannot be resolved.
	 *
	 * If the attempt is skipped or throws, the function is defined with
	 * DYNAMIC_STORAGE_FORMAL_PARAMS so Ghidra assigns storage itself. All failures are
	 * logged, never thrown.
	 *
	 * @param dataTypeMap                type lookup table; not read by this method
	 * @param functionAddress            entry point of the function
	 * @param functionName               name to assign
	 * @param returnValue                return type and its storage location
	 * @param functionParams             parameters and their storage locations
	 * @param assignReturnAddressStorage whether a return value in register "ra" may get explicit storage
	 * @param assignAnyStorage           whether to attempt explicit storage at all
	 */
	protected void DefineFunction2(HashMap<String, DataType> dataTypeMap, long functionAddress, String functionName, FunctionParameter returnValue, FunctionParameter[] functionParams, boolean assignReturnAddressStorage, boolean assignAnyStorage)
	{
		boolean createdWithExplicitStorage = false;
		Parameter[] funcParams = new Parameter[functionParams.length];
		String functionAddressAndName = String.format("0x%08X", functionAddress) + " - " + functionName;
		if (assignAnyStorage)
		{
			boolean gotStorageForAllParams = true;
			try
			{
				VariableStorage returnValueStorage = null;
				if (returnValue.IsRegisterParam)
				{
					// A return value in "ra" only gets explicit storage when explicitly requested.
					// "ra".equals(...) is null-safe, unlike RegisterName.equals("ra").
					boolean createWithStorage = assignReturnAddressStorage || !"ra".equals(returnValue.RegisterName);
					if (createWithStorage)
					{
						Register returnValueReg = findRegisterByName(returnValue.RegisterName);
						if (returnValueReg != null)
						{
							returnValueStorage = new VariableStorage(currentProgram, returnValueReg);
							println("Creating return value for function " + functionAddressAndName + " with explicit parameter storage in register " + returnValue.RegisterName);
						}
						else
						{
							createWithStorage = false;
						}
					}
					if (!createWithStorage)
					{
						// Not fatal: with null storage DefineFunction sets the return type via setReturnType.
						println("Creating return value for function " + functionAddressAndName + " without explicit parameter storage");
					}
				}
				else
				{
					returnValueStorage = new VariableStorage(currentProgram, returnValue.StackOffset, returnValue.ParameterDataType.getLength());
					println("Creating return value for function " + functionAddressAndName + " with explicit parameter storage at stack offset " + String.valueOf(returnValue.StackOffset));
				}
				for (int i = 0; i < functionParams.length; i++)
				{
					String paramNumAndName = String.valueOf(i) + "(" + functionParams[i].ParameterName + ")";
					if (functionParams[i].IsRegisterParam)
					{
						Register regParam = findRegisterByName(functionParams[i].RegisterName);
						if (regParam != null)
						{
							funcParams[i] = new ParameterImpl(functionParams[i].ParameterName, functionParams[i].ParameterDataType, regParam, currentProgram);
							println("Creating parameter " + paramNumAndName + " for function " + functionAddressAndName + " with explicit parameter storage in register " + functionParams[i].RegisterName);
						}
						else
						{
							funcParams[i] = new ParameterImpl(functionParams[i].ParameterName, functionParams[i].ParameterDataType, currentProgram);
							gotStorageForAllParams = false;
							println("Creating parameter " + paramNumAndName + " for function " + functionAddressAndName + " without explicit parameter storage");
						}
					}
					else
					{
						funcParams[i] = new ParameterImpl(functionParams[i].ParameterName, functionParams[i].ParameterDataType, functionParams[i].StackOffset, currentProgram);
						println("Creating parameter " + paramNumAndName + " for function " + functionAddressAndName + " with explicit parameter storage at stack offset " + String.valueOf(functionParams[i].StackOffset));
					}
				}
				// Only commit to CUSTOM_STORAGE when every parameter has a real location;
				// otherwise go straight to the dynamic-storage path below instead of
				// defining the function twice and logging a false success.
				if (gotStorageForAllParams)
				{
					Function created = DefineFunction(toAddr(functionAddress), functionName, returnValue.ParameterDataType, returnValueStorage, funcParams, Function.FunctionUpdateType.CUSTOM_STORAGE);
					if (created == null)
					{
						// DefineFunction has already logged the failure; a dynamic-storage retry
						// would presumably hit the same createFunction failure.
						return;
					}
					createdWithExplicitStorage = true;
					println("Created function " + functionAddressAndName + " with explicit parameter storage successfully.");
				}
			}
			catch (Exception e)
			{
				println(e.toString() + " thrown while creating or updating function '" + functionAddressAndName + "': " + e.getMessage() + " - will attempt to create without assigning parameter storage.");
				createdWithExplicitStorage = false;
			}
			if (!gotStorageForAllParams)
			{
				println("Could not determine parameter storage for one or more parameters while creating or updating function '" + functionAddressAndName + "' - will attempt to create without assigning parameter storage.");
				createdWithExplicitStorage = false;
			}
		}
		if (!createdWithExplicitStorage)
		{
			try
			{
				for (int i = 0; i < functionParams.length; i++)
				{
					funcParams[i] = new ParameterImpl(functionParams[i].ParameterName, functionParams[i].ParameterDataType, currentProgram);
				}
				Function created = DefineFunction(toAddr(functionAddress), functionName, returnValue.ParameterDataType, null, funcParams, Function.FunctionUpdateType.DYNAMIC_STORAGE_FORMAL_PARAMS);
				// A null result means DefineFunction could not create the function (already logged).
				if (created != null)
				{
					println("Created function " + functionAddressAndName + " without explicit parameter storage successfully.");
				}
			}
			catch (Exception e)
			{
				println(e.toString() + " thrown while creating or updating function '" + functionAddressAndName + "': " + e.getMessage() + ", even without explicit storage. Manual repair of the function signature in Ghidra will be required.");
			}
		}
	}
	
	/**
	 * Resolves a data type by name through findDataTypeByName, logging a message
	 * when nothing matches.
	 *
	 * @return the data type, or null when it cannot be found
	 */
	protected DataType getDataTypeByName(String dataTypeName)
	{
		DataType match = findDataTypeByName(dataTypeName);
		if (match == null)
		{
			println("Could not find a datatype with the name '" + dataTypeName + "'");
		}
		return match;
	}
	
	/**
	 * Registers the Ghidra data type named fullGhidraName in dataTypeMap.
	 * The key is "typeName dataTypeName", or just dataTypeName when typeName is
	 * empty (typeName is presumably a keyword prefix such as "struct" - verify
	 * against the generated callers). Nothing is added when the type is not found.
	 *
	 * @return the same map instance that was passed in
	 */
	protected HashMap<String, DataType> addDataTypeByName(HashMap<String, DataType> dataTypeMap, String fullGhidraName, String dataTypeName, String typeName)
	{
		DataType resolved = findDataTypeByName(fullGhidraName);
		if (resolved == null)
		{
			return dataTypeMap;
		}
		String key = typeName.equals("") ? dataTypeName : typeName + " " + dataTypeName;
		dataTypeMap.put(key, resolved);
		return dataTypeMap;
	}
	
	// from PrintStructureScript.java
	/**
	 * Returns the first data type found under the given name/path among the data type
	 * managers known to the tool's DataTypeManagerService.
	 * When there is no tool (as in headless mode) or the tool has no such service, only
	 * the current program's data type manager is searched instead of failing with an NPE.
	 *
	 * @return the data type, or null when no manager has one by that name
	 */
	private DataType findDataTypeByName(String name)
	{
		PluginTool tool = state.getTool();
		DataTypeManagerService service = (tool == null) ? null : tool.getService(DataTypeManagerService.class);
		DataTypeManager[] dataTypeManagers;
		if (service != null)
		{
			dataTypeManagers = service.getDataTypeManagers();
		}
		else
		{
			dataTypeManagers = new DataTypeManager[] { currentProgram.getDataTypeManager() };
		}
		for (DataTypeManager manager : dataTypeManagers)
		{
			DataType dataType = manager.getDataType(name);
			if (dataType != null)
			{
				return dataType;
			}
		}
		return null;
	}

	/**
	 * Resolves a register by name through the current program's ProgramContext,
	 * logging a message when no such register exists.
	 *
	 * @return the register, or null when it is not found
	 */
	private Register findRegisterByName(String name)
	{
		ProgramContext programContext = currentProgram.getProgramContext();
		Register register = programContext.getRegister(name);
		if (register != null)
		{
			return register;
		}
		println("Could not find a register with the name '" + name + "'");
		return null;
	}

	/**
	 * Removes any data defined at each byte address in
	 * [startAddress, startAddress + length). This is best-effort: a failure at one
	 * address is logged and the loop continues with the next address.
	 */
	protected void clearExistingData(long startAddress, long length)
	{
		long endAddress = startAddress + length;
		for (long i = startAddress; i < endAddress; i++)
		{
			try
			{
				removeDataAt(toAddr(i));
			}
			catch (Exception e)
			{
				println(e.toString() + " thrown while clearing existing data at address " + String.valueOf(i) + ": " + e.getMessage());
			}
		}
	}
	
	/**
	 * Clears any existing data over typeSize bytes starting at startOffset, then
	 * defines dt at startOffset. A failure to create the data is logged, not thrown.
	 */
	protected void ClearAndCreateData(long startOffset, long typeSize, DataType dt)
	{
		clearExistingData(startOffset, typeSize);
		try
		{
			Address start = toAddr(startOffset);
			createData(start, dt);
		}
		catch (Exception ex)
		{
			println(ex.toString() + " thrown while creating data at address " + String.valueOf(startOffset) + ": " + ex.getMessage());
		}
	}
	
	/**
	 * Sets registerName to registerValue in ctx for the single address targetAddress
	 * (the range start and end are the same address). Failures are logged, not thrown.
	 */
	protected void SetRegisterAtAddress(ProgramContext ctx, String registerName, long targetAddress, long registerValue)
	{
		try
		{
			Address ta = toAddr(targetAddress);
			BigInteger v = BigInteger.valueOf(registerValue);
			Register r = findRegisterByName(registerName);
			if (r == null)
			{
				println("Error: could not find a register named '" + registerName + "' in order to set its value.");
			}
			else
			{
				ctx.setRegisterValue(ta, ta, new RegisterValue(r, v));
			}
		}
		catch (Exception e)
		{
			// Include the value being set; the old message read "to at offset" with the value missing.
			println(e.toString() + " thrown while setting register '" + registerName + "' to " + String.format("0x%08X", registerValue) + " at offset " + String.format("0x%08X", targetAddress) + ": " + e.getMessage());
		}
	}
	
/**
 * Adds the enum types from the debug symbols to dataTypeMap. The body is generated:
 * the %...% marker line is presumably replaced with calls to the enum import
 * subroutines emitted further down this file.
 *
 * @return dataTypeMap (possibly reassigned by the generated calls)
 */
protected HashMap<String, DataType> importEnums(HashMap<String, DataType> dataTypeMap)
	{
%ENUM_IMPORT_SUBROUTINE_CALLS%
		return dataTypeMap;
	}

/**
 * Adds the struct types from the debug symbols to dataTypeMap. The body is generated:
 * the %...% marker line is presumably replaced with calls to the struct import
 * subroutines emitted further down this file.
 *
 * @return dataTypeMap (possibly reassigned by the generated calls)
 */
protected HashMap<String, DataType> importStructs(HashMap<String, DataType> dataTypeMap)
	{
%STRUCT_IMPORT_SUBROUTINE_CALLS%
		return dataTypeMap;
	}

/**
 * Adds the union types from the debug symbols to dataTypeMap. The body is generated:
 * the %...% marker line is presumably replaced with calls to the union import
 * subroutines emitted further down this file.
 *
 * @return dataTypeMap (possibly reassigned by the generated calls)
 */
protected HashMap<String, DataType> importUnions(HashMap<String, DataType> dataTypeMap)
	{
%UNION_IMPORT_SUBROUTINE_CALLS%
		return dataTypeMap;
	}

/**
 * Adds function-pointer types to dataTypeMap. run() calls this after the functions
 * without function-pointer parameters are defined and before those with them.
 * The body is generated from the %...% marker line.
 *
 * @return dataTypeMap (possibly reassigned by the generated code)
 */
protected HashMap<String, DataType> importFunctionPointers(HashMap<String, DataType> dataTypeMap)
	{
%FUNCTION_POINTER_IMPORTS%
		return dataTypeMap;
	}

// DataType import subroutines (begin)

// enum import subroutines (begin)
%ENUM_IMPORT_SUBROUTINES%
// enum import subroutines (end)

// struct import subroutines (begin)
%STRUCT_IMPORT_SUBROUTINES%
// struct import subroutines (end)

// union import subroutines (begin)
%UNION_IMPORT_SUBROUTINES%
// union import subroutines (end)

// DataType import subroutines (end)
	
// Function import subroutines for functions without function pointer parameters (begin)
%FUNCTION_IMPORT_SUBROUTINES%
// Function import subroutines for functions without function pointer parameters (end)

// Function import subroutines for functions with function pointer parameters (begin)
%FUNCTION_POINTER_FUNCTION_IMPORT_SUBROUTINES%
// Function import subroutines for functions with function pointer parameters (end)



}