// reportfault.cpp
//
// Test code for ReportFault() API
//
// Claus Brod, http://www.clausbrod.de/Blog

#include <stdio.h>
#include <tchar.h>

#define _WIN32_WINNT 0x0500
#include <windows.h>
#include <ErrorRep.h>

#include <atlbase.h>

// Unrecoverable condition: writes the message s to the debugger output
// (OutputDebugStringA) and then raises a breakpoint via DebugBreak().
#define FATAL_ERROR(s)          \
  do {                          \
    OutputDebugStringA(s);      \
    DebugBreak();               \
  } while (0)

// Address of ReportFault(), looked up in FaultRep.dll by _tmain();
// NULL until then, and stays NULL if the lookup fails.
static pfn_REPORTFAULT pReportFault = NULL;

//--------- Launch debugger ----------------------------------------------------

// Returns the JIT debugger command template stored in the "Debugger" value of
// HKLM\Software\Microsoft\Windows NT\CurrentVersion\AeDebug, or NULL if the
// key cannot be opened, the value cannot be read completely, or it is empty.
//
// The result points into a static buffer that the next call overwrites, so
// the function is not reentrant.
TCHAR *installed_debugger()
{
  CRegKey key;
  if (ERROR_SUCCESS != key.Open(HKEY_LOCAL_MACHINE,
    __T("Software\\Microsoft\\Windows NT\\CurrentVersion\\AeDebug\\"), KEY_READ))
    return NULL;

  static TCHAR debuggerPath[1024];
  debuggerPath[0] = 0;
  DWORD dwCount = _countof(debuggerPath);  // in: buffer size, out: chars read

  // On failure (e.g. value missing, too long for the buffer, or not a string)
  // the buffer contents are not reliable, so never hand them out.
  if (ERROR_SUCCESS != key.QueryStringValue(__T("Debugger"), debuggerPath, &dwCount)) {
    debuggerPath[0] = 0;
    return NULL;
  }

  return debuggerPath[0] ? debuggerPath : NULL;
}

static bool launch_debugger(void)
{
  TCHAR *debugger = installed_debugger();
  if (!debugger)
    return false;

  bool debuggerLaunched = false;

  // create event
  SECURITY_ATTRIBUTES attributes = { sizeof(attributes), 0, TRUE };
  HANDLE debugEvent = CreateEvent(&attributes, TRUE, FALSE, NULL);
  if (debugEvent) {
    // get process ID
    DWORD pid = GetCurrentProcessId();

    // build command string
    TCHAR commandBuffer[1024];
    _sntprintf_s(commandBuffer, _countof(commandBuffer), _TRUNCATE,
      debugger, pid, debugEvent);

    // start JIT debugger
    STARTUPINFO sinfo;
    memset(&sinfo, 0, sizeof(sinfo));
    sinfo.cb = sizeof(sinfo);

    PROCESS_INFORMATION pinfo;
    BOOL createProcRet = CreateProcess(NULL, commandBuffer, 
      NULL, NULL, 
      TRUE,  // inherit (event) handles
      0, NULL, NULL, &sinfo, &pinfo);

    if (createProcRet) {
      // wait until debugger fires event
      DWORD dwWaitResult = 0;
      dwWaitResult = WaitForSingleObject(debugEvent, 1000*100);
      if (dwWaitResult == WAIT_OBJECT_0) {
        debuggerLaunched = true;
      }
    } else {
      FATAL_ERROR("Could not run debugger process.\n");
    }
  }

  return debuggerLaunched;
}

//--------- Report crash using ReportFault() -------------------------------------

// Passes the exception described by inExceptionPointer to ReportFault()
// (resolved at runtime into pReportFault) and acts on the result.
//
// Returns true when ReportFault() reports success (frrvOk, frrvOkHeadless,
// frrvOkManifest, frrvOkQueued), or when the user chose to debug and
// launch_debugger() succeeded (execution then stops at a DebugBreak()).
// Returns false if pReportFault is NULL, for frrvErr, and for any result
// not listed here. frrvErrTimeout and frrvErrNoDW invoke FATAL_ERROR.
bool call_report_fault(_EXCEPTION_POINTERS *inExceptionPointer)
{
  // Without a resolved ReportFault() there is nothing to call; crashing on a
  // null function pointer inside an exception filter would be far worse.
  if (!pReportFault)
    return false;

  if (IsDebuggerPresent()) {
    // Informational only: ReportFault() is still called so its behaviour with
    // an attached debugger can be observed.
    _tprintf(__T("Debugger is attached.\n"));
  }

  bool ret = false;

  DWORD dwOpt = 0;
  EFaultRepRetVal repret = (*pReportFault)(inExceptionPointer, dwOpt);
  switch (repret)
  {
  case frrvLaunchDebugger:
    if (launch_debugger()) {
      // stop in debugger
      DebugBreak();
      ret = true;
    }
    break;

  case frrvErrTimeout:
    FATAL_ERROR("call_report_fault(): Timeout detected.\n");
    break;

  case frrvErrNoDW:
    // Client unable to launch, perform default exception handling
    FATAL_ERROR("call_report_fault(): Client unable to launch\n");
    break;

  case frrvErr:
    // ReportFault failed, but client was launched
    break;

  case frrvOk:
  case frrvOkHeadless:   // succeeded, client launched in silent mode
  case frrvOkManifest:   // succeeded, client launched in manifest reporting mode
  case frrvOkQueued:     // succeeded, fault queued for later reporting
    ret = true;
    break;

  default:
    // Unknown / newer result codes are treated as failure.
    break;
  }

  return ret;
}

// SEH filter expression: prints the exception record's flags, hands the
// exception to call_report_fault(), and always selects the __except block.
static int filter_exception(EXCEPTION_POINTERS *exc_ptr)
{
  const DWORD exceptionFlags = exc_ptr->ExceptionRecord->ExceptionFlags;
  _tprintf(__T("filter_exception - flags=%lx\n"), exceptionFlags);

  (void)call_report_fault(exc_ptr);  // outcome does not affect the filter

  return EXCEPTION_EXECUTE_HANDLER;
}

// Provokes an access violation by storing through a null pointer inside
// __try, so that filter_exception() runs and the __except block is entered.
// Then sleeps for 5 seconds before returning.
void wedding_crasher(void)
{
  __try {
    // volatile: a store through a null pointer is undefined behaviour, so a
    // non-volatile access may be dropped by the optimizer and never fault.
    volatile int *foo = (volatile int *)0;
    *foo = 42;
  } __except(filter_exception(GetExceptionInformation())) {
    _tprintf(__T("Now in exception handler, process is still alive!\n"));
  }
  Sleep(5000);
}

int _tmain(void)
{
  HMODULE hMod = LoadLibrary(__T("FaultRep.dll"));
  if (hMod)
    pReportFault = (pfn_REPORTFAULT)GetProcAddress(hMod, "ReportFault");

  if (pReportFault)
    wedding_crasher();
  return 1;
}
