/*****************************************************************************\
*                                                                             *
*  Name   : com_array                                                         *
*  Author : Chris Koeritz                                                     *
*                                                                             *
*******************************************************************************
* Copyright (c) 1998-$now By Author.  This program is free software; you can  *
* redistribute it and/or modify it under the terms of the GNU General Public  *
* License as published by the Free Software Foundation; either version 2 of   *
* the License or (at your option) any later version.  This is online at:      *
*     http://www.fsf.org/copyleft/gpl.html                                    *
* Please send any updates to: fred@gruntose.com                               *
\*****************************************************************************/

#include "com_array.h"

#include <basis/common_templates.h>

// hmmm: add a selectable checking phase for all places that use an index array.

namespace com_extensions {

// creates an empty wrapper: no SAFEARRAY is held and the element type
// defaults to BYTES.
com_array::com_array()
: _implementation(NIL),
  _type(BYTES)
{
}

// creates a new SAFEARRAY of element "type" with "dimensions" dimensions whose
// extents come from "ranges" (see reset()).
com_array::com_array(com_types type, int dimensions,
    const array<dim_bound> &ranges)
: _implementation(NIL),
  _type(BYTES)
{
  reset(type, dimensions, ranges);
}

// same as the com_types version, but takes the raw VARTYPE for the elements.
com_array::com_array(VARTYPE type, int dimensions,
    const array<dim_bound> &ranges)
: _implementation(NIL),
  _type(BYTES)
{
  reset(com_types(type), dimensions, ranges);
}

// wraps an existing SAFEARRAY without copying it; this object takes it over
// (the destructor will destroy it).
com_array::com_array(com_types type, SAFEARRAY *initial)
: _implementation(initial),
  _type(type)
{
}

// same as above, but takes the raw VARTYPE for the elements.
com_array::com_array(VARTYPE type, SAFEARRAY *initial)
: _implementation(initial),
  _type(com_types(type))
{
}

// builds a com_array from a variant that holds a SAFEARRAY (VT_ARRAY, possibly
// combined with VT_BYREF).  if "to_snag" does not hold an array, this object is
// left empty.  when "copy_it" is true, a private copy of the array is made and
// "to_snag" is not modified; otherwise "to_snag" is detached and this object
// takes over the array it held.
com_array::com_array(_variant_t &to_snag, bool copy_it)
: _implementation(NIL), _type(BYTES)
{
  if (! (to_snag.vt & VT_ARRAY) ) return;  // not a safe array.
  _type = com_types(to_snag.vt & ~(VT_ARRAY | VT_BYREF));
    // snag the type without special modifiers.
  _variant_t to_use;
  // "to_use" becomes either a copy of "to_snag" or attaches to it.
  if (copy_it) {
    // VariantCopyInd follows a VT_BYREF indirection and copies the referenced
    // array, so the copy we keep is always ours.  plain VariantCopy would only
    // duplicate the reference, and our destructor would then destroy an array
    // that still belongs to whoever owns the referenced SAFEARRAY pointer.
    if (VariantCopyInd(&to_use, &to_snag) != S_OK) return;  // failure.
  } else to_use.Attach(to_snag.Detach());
  // get the right field, depending on whether it's a reference or not.
  // NOTE(review): in the non-copy VT_BYREF case the referenced array is still
  // reachable through the caller's pointer; taking ownership of it here may
  // lead to a double destroy if that owner also frees it -- confirm usage.
  if (to_use.vt & VT_BYREF) _implementation = *to_use.pparray;
  else _implementation = to_use.parray;
  to_use.Detach();  // toss its link to the variant.
}

// releases the held SAFEARRAY, if any, and forgets it.
com_array::~com_array()
{
  if (_implementation != NIL) {
    SafeArrayDestroy(_implementation);
    _implementation = NIL;
  }
}

// returns the held SAFEARRAY for modification (may be NIL).
SAFEARRAY *com_array::access()
{
  return _implementation;
}

// returns the held SAFEARRAY for read-only inspection (may be NIL).
const SAFEARRAY *com_array::observe() const
{
  return _implementation;
}

// number of dimensions in the held array, or zero when nothing is held.
int com_array::dimensions() const
{
  if (!_implementation) return 0;
  return int(SafeArrayGetDim(_implementation));
}

// size in bytes of one element of the held array, or zero when nothing is held.
int com_array::element_size() const
{
  if (!_implementation) return 0;
  return int(SafeArrayGetElemsize(_implementation));
}

// returns the element count and lower bound of the zero-based "dimension".
// an empty bound (0 elements, lower bound 0) is returned when no array is held
// or when the dimension cannot be queried (e.g. it is out of range).
dim_bound com_array::get_bounds(int dimension) const
{
  if (!_implementation) return dim_bound(0, 0);
  long lower = 0, upper = -1;  // the bounds; initialized in case a query fails.
  // the SAFEARRAY API numbers dimensions starting at one.  if either query
  // fails, the outputs are not trustworthy, so report an empty dimension.
  if ( (SafeArrayGetLBound(_implementation, dimension + 1, &lower) != S_OK)
      || (SafeArrayGetUBound(_implementation, dimension + 1, &upper) != S_OK) )
    return dim_bound(0, 0);
  return dim_bound(upper - lower + 1, lower);
}

// fetches the element at "position" into "to_examine".  "position" must hold
// at least dimensions() indices; they are handed straight to
// SafeArrayGetElement.  returns false when no array is held, too few indices
// are given, the element cannot be fetched (e.g. an index is out of range), or
// the element type is not one handled below.
bool com_array::get(int_array &position, _variant_t &to_examine) const
{
  if (!_implementation) return false;
  if (position.length() < dimensions()) return false;
  byte_array holder(element_size());
  // the int indices are reinterpreted as longs, which assumes the two types
  // have the same size (as on win32/win64).  if the fetch fails, the holder was
  // never filled in, so it must not be interpreted below.
  if (SafeArrayGetElement(_implementation, (long *)position.access(),
      holder.access()) != S_OK)
    return false;
  switch (_type) {
    case BYTES: to_examine = short(*holder.observe()); break;
    case SHORTS: to_examine = *(short int *)holder.observe(); break;
    case LONGS: to_examine = *(int *)holder.observe(); break;
    case FLOATS: to_examine = *(float *)holder.observe(); break;
    case DOUBLES: to_examine = *(double *)holder.observe(); break;
    case MONEY: to_examine = *(CY *)holder.observe(); break;
    case DATES: to_examine = *(DATE *)holder.observe(); break;
    case ERRORS: to_examine = *(int *)holder.observe(); break;
    case BOOLS: to_examine = *(short int *)holder.observe(); break;

    // strings are a special case because the SAFEARRAY hands back a freshly
    // allocated copy; the variant makes its own copy, so ours must be freed.
    case STRINGS: {
      BSTR fetched = *(BSTR *)holder.observe();
      to_examine = fetched;
      SysFreeString(fetched);
      break;
    }

    // variants are a special case because the returned object is allocated
    // by the SAFEARRAY.
    case VARIANTS: {
      VARIANT tmp = *(VARIANT *)holder.observe();
      to_examine = tmp;
      VariantClear(&tmp);
      break;
    }

    // special case for interface pointers since the SAFEARRAY automatically
    // adds a reference to the interface.  a null element carries no
    // reference, so there is nothing to release for it.
    case UNKNOWNS: {
      IUnknown *to_set = *(IUnknown **)holder.observe();
      to_examine = to_set;
      if (to_set) to_set->Release();
      break;
    }
    case DISPATCHES: {
      IDispatch *to_set = *(IDispatch **)holder.observe();
      to_examine = to_set;
      if (to_set) to_set->Release();
      break;
    }
    default: return false;
  }
  return true;
}

// CONVERT takes the variant passed in, converts it to the type held in the
// com_array and stores it at "position".  if the store fails (for example,
// because an index is outside the array's bounds), put() returns false.
#define CONVERT(type) { \
  type tmp = (type)to_store; \
  if (SafeArrayPutElement(_implementation, (long *)position.access(), \
      &tmp) != S_OK) \
    return false; \
  break; \
}

// BSTR_CONVERT works a little differently than the normal convert: a string
// element is passed to SafeArrayPutElement as the BSTR itself rather than a
// pointer to it.  our temporary copy is freed whether or not the store works.
#define BSTR_CONVERT(type) { \
  type tmp = to_store; \
  BSTR tmp2 = tmp.copy(); \
  bool stored = SafeArrayPutElement(_implementation, \
      (long *)position.access(), tmp2) == S_OK; \
  SysFreeString(tmp2); \
  if (!stored) return false; \
  break; \
}

// POINTER_CONVERT operates on simple interface pointers.  the pointer pulled
// out of the variant carries a reference (when it is non-null) that is dropped
// again after the store; a null pointer has no reference to drop.
#define POINTER_CONVERT(type) { \
  type tmp = (type)to_store; \
  bool stored = SafeArrayPutElement(_implementation, \
      (long *)position.access(), tmp) == S_OK; \
  if (tmp) tmp->Release(); \
  if (!stored) return false; \
  break; \
}

// stores "to_store", converted to the array's element type, at "position".
// "position" must hold at least dimensions() indices.  returns false if no
// array is held, too few indices are given, the conversion throws (the
// variant holds an incompatible type), the element type is unsupported, or
// the SAFEARRAY rejects the store.
bool com_array::put(int_array &position, const _variant_t &to_store)
{
  if (!_implementation) return false;
  if (position.length() < dimensions()) return false;
  try {
    switch (_type) {
      case BYTES: CONVERT(BYTE);
      case SHORTS: CONVERT(short int);
      case LONGS: CONVERT(int);
      case FLOATS: CONVERT(float);
      case DOUBLES: CONVERT(double);
      case MONEY: CONVERT(CY);
      case DATES: CONVERT(DATE);
      case STRINGS: BSTR_CONVERT(_bstr_t);
      case ERRORS: CONVERT(int);
      case BOOLS: CONVERT(short int);
      case VARIANTS: CONVERT(VARIANT);
      case UNKNOWNS: POINTER_CONVERT(IUnknown *);
      case DISPATCHES: POINTER_CONVERT(IDispatch *);
      default: return false;
    }
  } catch (...) {
    // if they got to here, then a COM conversion failed.  probably the variant
    // passed in is the wrong type.
    return false;
  }
  return true;
}

// the conversion helpers are private to put(); keep them from leaking further.
#undef CONVERT
#undef BSTR_CONVERT
#undef POINTER_CONVERT

// throws away any held array and creates a new SAFEARRAY of element "type"
// with "dimensions" dimensions.  the first "dimensions" entries of "ranges"
// supply each dimension's lower bound and element count; any extra entries
// are ignored.  nothing changes if "dimensions" is negative or "ranges" is too
// short.  if SafeArrayCreate fails, the object is left holding no array.
void com_array::reset(com_types type, int dimensions,
    const array<dim_bound> &ranges)
{
  if ( (dimensions < 0) || (ranges.length() < dimensions) ) return;
//hmmm: report error?
  if (_implementation) {
    SafeArrayDestroy(_implementation);
    _implementation = NIL;
  }
  _type = type;
  array<SAFEARRAYBOUND> bounds(dimensions);
  // only "dimensions" slots exist in "bounds", so iterating over all of
  // "ranges" would write past its end whenever extra ranges were passed.
  for (int i = 0; i < dimensions; i++) {
    bounds[i].lLbound = ranges[i]._lower_bound;
    bounds[i].cElements = ranges[i]._elements;
  }
  _implementation = SafeArrayCreate((VARTYPE)type, dimensions,
      bounds.access());
}

// increments the lock count on the held SAFEARRAY, if there is one.  the
// HRESULT from SafeArrayLock is currently discarded.
//hmmm: make these catch the return value in case unexpected weirdness happens.
void com_array::lock()
{
  if (!_implementation) return;
  SafeArrayLock(_implementation);
}

// decrements the lock count on the held SAFEARRAY, if there is one.  the
// HRESULT from SafeArrayUnlock is currently discarded.
//hmmm: make these catch the return value in case unexpected weirdness happens.
void com_array::unlock()
{
  if (!_implementation) return;
  SafeArrayUnlock(_implementation);
}

// takes over "initial" (which may be NIL) as the held array with element type
// "type", destroying any previously held array.  returns true if an array is
// now held.
bool com_array::attach(com_types type, SAFEARRAY *initial)
{
  // re-attaching the array we already hold must not destroy it, or we would
  // be left holding a dangling pointer.
  if (_implementation && (_implementation != initial))
    SafeArrayDestroy(_implementation);
  _type = type;
  _implementation = initial;
  return _implementation != NIL;
}

int com_array::index(const int_array &to_compute) const
{
  if (to_compute.length() > dimensions()) return 0;  // bad index array.

  // the index is computed by zero-biasing the indices and multiplying
  // by the number of elements in each stage.  the first index is least
  // significant; the later indices get multiplied by all of the previous
  // ranges.
  int index = 0;
  int multiplier = 1;
  for (int i = 0; i < to_compute.length(); i++) {
    dim_bound current = get_bounds(i);
    index += (to_compute[i] - current._lower_bound) * multiplier;
    multiplier *= current._elements;
  }
  return index;
}

// stores the held array in "hold_array" as a VT_ARRAY variant of this array's
// element type.  "hold_array" is always cleared first.  with "copy_it" the
// variant receives a copy and this object keeps its array; otherwise the
// array is handed to the variant and this object becomes empty.  if nothing is
// held or the copy fails, "hold_array" is left empty.
void com_array::encapsulate(_variant_t &hold_array, bool copy_it)
{
  hold_array.Clear();
  if (!_implementation) return;  // nil, so leave the output empty.
  SAFEARRAY *payload = NIL;
  if (!copy_it) {
    // give our array away to the variant.
    payload = _implementation;
    _implementation = NIL;
  } else if (SafeArrayCopy(_implementation, &payload) != S_OK) {
    return;  // failure.
  }
  hold_array.parray = payload;  // fill the contents.
  hold_array.vt = VT_ARRAY | _type;  // fill the type in the variant.
}

} // namespace

