//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Coord/CoordSystem2D.cpp
//! @brief     Implements class CoordSystem2D and its subclasses.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Coord/CoordSystem2D.h"
#include "Base/Axis/MakeScale.h"
#include "Base/Axis/Scale.h"
#include "Base/Const/Units.h"
#include "Base/Pixel/RectangularPixel.h"
#include "Base/Util/Assert.h"
#include "Base/Vector/GisasDirection.h"
#include <algorithm>
#include <cmath>
#include <numbers>

using std::numbers::pi;

namespace {

//! Returns the angle (in radians) of the outgoing wavevector `kf` along detector axis `i_axis`.
//!
//! @param i_axis  0 for the horizontal axis, 1 for the vertical axis; anything else asserts.
//! @param kf      outgoing wavevector; must not be the zero vector.
//! @return        axis 0: minus atan2(kf.y, kf.x);
//!                axis 1: pi/2 minus the polar angle atan2(kf.magxy(), kf.z).
double axisAngle(size_t i_axis, R3 kf)
{
    ASSERT(kf != R3());
    if (i_axis == 0)
        // minus sign because sample y and detector u point in opposite directions
        return -std::atan2(kf.y(), kf.x());
    if (i_axis == 1)
        // complement of the polar angle measured from the z axis
        return (pi / 2) - std::atan2(kf.magxy(), kf.z());
    ASSERT(false);
}

} // namespace

//  ************************************************************************************************
//  class CoordSystem2D
//  ************************************************************************************************

CoordSystem2D::CoordSystem2D(std::vector<const Scale*>&& axes)
    : ICoordSystem(std::move(axes))
{
}

CoordSystem2D::CoordSystem2D(const CoordSystem2D& other)
    : CoordSystem2D(other.m_axes.cloned_vector())
{
}

double CoordSystem2D::calculateMin(size_t i_axis, Coords units) const
{
    ASSERT(i_axis < rank());
    units = substituteDefaultUnits(units);
    if (units == Coords::NBINS)
        return 0.0;
    return calculateValue(i_axis, units, m_axes[i_axis]->min());
}

double CoordSystem2D::calculateMax(size_t i_axis, Coords units) const
{
    ASSERT(i_axis < rank());
    units = substituteDefaultUnits(units);
    if (units == Coords::NBINS)
        return m_axes[i_axis]->size();
    return calculateValue(i_axis, units, m_axes[i_axis]->max());
}

std::vector<Coords> CoordSystem2D::availableUnits() const
{
    return {Coords::NBINS, Coords::RADIANS, Coords::DEGREES};
}

//! Builds an equidistant axis with as many bins as axis `i_axis`, whose limits and name are
//! expressed in `units`. Returns the pointer produced by newEquiDivision (presumably owned by
//! the caller; confirm against MakeScale.h).
Scale* CoordSystem2D::convertedAxis(size_t i_axis, Coords units) const
{
    const double lo = calculateMin(i_axis, units);
    const double hi = calculateMax(i_axis, units);
    const std::string name = nameOfAxis(i_axis, units);
    return newEquiDivision(name, m_axes[i_axis]->size(), lo, hi);
}

//  ************************************************************************************************
//  class SphericalCoords
//  ************************************************************************************************

//! @param axes  the two detector axes (phi_f, alpha_f); they are handed to the base class.
//! @param ki    incident wavevector, used for the Q-space conversion.
SphericalCoords::SphericalCoords(std::vector<const Scale*>&& axes, const R3& ki)
    : CoordSystem2D(std::move(axes))
    , m_ki(ki)
{
    // `axes` has already been moved from at this point, so check the axes the base stored.
    ASSERT(rank() == 2);
}

//! Member-wise copy; the base part goes through the CoordSystem2D copy constructor,
//! which rebuilds the axes from cloned_vector().
SphericalCoords::SphericalCoords(const SphericalCoords& other) = default;

//! Defined out of line; performs only the default member-wise destruction.
SphericalCoords::~SphericalCoords() = default;

//! Returns a heap-allocated copy made with the copy constructor.
SphericalCoords* SphericalCoords::clone() const
{
    auto* copy = new SphericalCoords(*this);
    return copy;
}

std::vector<Coords> SphericalCoords::availableUnits() const
{
    auto result = CoordSystem2D::availableUnits();
    result.push_back(Coords::QSPACE);
    return result;
}

//! Converts the angle `value` (radians) on axis `i_axis` (0: phi_f, 1: alpha_f) into `units`.
//!
//! RADIANS returns `value` unchanged, and DEGREES converts it.
//! For QSPACE, kf is built with vecOfKAlphaPhi. The argument order is assumed from the name
//! to be (k, alpha, phi).
//!  - Axis 0 returns the y component of ki - kf.
//!  - Axis 1 returns the component of kf - ki along the unit vector of ki x y-hat.
//! Any other unit (e.g. NBINS, which CoordSystem2D::calculateMin/Max handle before calling)
//! or axis index asserts.
double SphericalCoords::calculateValue(size_t i_axis, Coords units, double value) const
{
    switch (units) {
    case Coords::RADIANS:
        return value;
    case Coords::DEGREES:
        return Units::rad2deg(value);
    case Coords::QSPACE: {
        if (i_axis == 0) {
            // u axis runs in -y direction
            const R3 kf = vecOfKAlphaPhi(m_ki.mag(), 0.0, value);
            return (m_ki - kf).y();
        }
        if (i_axis == 1) {
            // v axis is perpendicular to ki and y.
            const R3 kf = vecOfKAlphaPhi(m_ki.mag(), value);
            // Deliberately not `static`: the direction depends on this object's m_ki. A
            // function-local static would freeze the first instance's ki for all instances.
            const R3 unit_v = m_ki.cross(R3(0, 1, 0)).unit_or_throw();
            return (kf - m_ki).dot(unit_v);
        }
        ASSERT(false);
        return 0;
    }
    default:
        ASSERT(false);
    }
}

//! Axis label for axis `i_axis` (0: phi_f / Qy, 1: alpha_f / Qz) in `units`.
//! Units without a dedicated label get the degree label; any other axis index asserts.
std::string SphericalCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    if (i_axis != 0 && i_axis != 1)
        ASSERT(false);
    const bool horizontal = (i_axis == 0);
    switch (units) {
    case Coords::NBINS:
        return horizontal ? "X [nbins]" : "Y [nbins]";
    case Coords::RADIANS:
        return horizontal ? "phi_f [rad]" : "alpha_f [rad]";
    case Coords::QSPACE:
        return horizontal ? "Qy [1/nm]" : "Qz [1/nm]";
    case Coords::DEGREES:
    default:
        return horizontal ? "phi_f [deg]" : "alpha_f [deg]";
    }
}

//  ************************************************************************************************
//  class ImageCoords
//  ************************************************************************************************

//! @param axes                   the two detector axes; they are handed to the base class.
//! @param ki                     incident wavevector.
//! @param regionOfInterestPixel  pixel used to map detector coordinates to directions.
//!                               It is stored in m_detector_pixel, and copies clone it.
//!                               Presumably ownership is taken; confirm against the header.
ImageCoords::ImageCoords(std::vector<const Scale*>&& axes, const R3& ki,
                         const RectangularPixel* regionOfInterestPixel)
    : CoordSystem2D(std::move(axes))
    , m_detector_pixel(regionOfInterestPixel)
    , m_ki(ki)
{
    // `axes` has already been moved from at this point, so check the axes the base stored.
    ASSERT(rank() == 2);
}

ImageCoords::ImageCoords(const ImageCoords& other)
    : CoordSystem2D(other)
    , m_detector_pixel(other.m_detector_pixel->clone())
    , m_ki(other.m_ki)
{
}

//! Defined out of line; performs only the default member-wise destruction.
ImageCoords::~ImageCoords() = default;

//! Returns a heap-allocated copy made with the copy constructor.
ImageCoords* ImageCoords::clone() const
{
    auto* copy = new ImageCoords(*this);
    return copy;
}

std::vector<Coords> ImageCoords::availableUnits() const
{
    auto result = CoordSystem2D::availableUnits();
    result.push_back(Coords::QSPACE);
    result.push_back(Coords::MM);
    return result;
}

//! Converts detector coordinate `value` on axis `i_axis` into `units`.
//!
//! MM returns `value` unchanged. Otherwise the point is placed at
//! getPosition(0,0) + (value - axis min) along the direction towards getPosition(1,0)
//! (axis 0) or getPosition(0,1) (axis 1). kf is then that direction scaled to |ki|.
//! This assumes getPosition maps fractional pixel coordinates to position vectors; confirm in
//! RectangularPixel.
//! RADIANS and DEGREES return the axisAngle of kf. QSPACE returns the y component of ki - kf
//! (axis 0) or the component of kf - ki along the unit vector of ki x y-hat (axis 1).
//! Any other unit/axis combination asserts.
double ImageCoords::calculateValue(size_t i_axis, Coords units, double value) const
{
    if (units == Coords::MM)
        return value;
    const auto k00 = m_detector_pixel->getPosition(0.0, 0.0);
    const auto k01 = m_detector_pixel->getPosition(0.0, 1.0);
    const auto k10 = m_detector_pixel->getPosition(1.0, 0.0);
    const auto& max_pos = i_axis == 0 ? k10 : k01; // position of max along given axis
    const double shift = value - m_axes[i_axis]->min();
    const R3 out_dir = k00 + shift * (max_pos - k00).unit_or_throw();
    const R3 kf = out_dir.unit_or_throw() * m_ki.mag();

    switch (units) {
    case Coords::RADIANS:
        return axisAngle(i_axis, kf);
    case Coords::DEGREES:
        return Units::rad2deg(axisAngle(i_axis, kf));
    case Coords::QSPACE: {
        if (i_axis == 0)
            // u axis runs in -y direction
            return (m_ki - kf).y();
        if (i_axis == 1) {
            // v axis is perpendicular to ki and y.
            // Deliberately not `static`: the direction depends on this object's m_ki. A
            // function-local static would freeze the first instance's ki for all instances.
            const R3 unit_v = m_ki.cross(R3(0, 1, 0)).unit_or_throw();
            return (kf - m_ki).dot(unit_v);
        }
    } break;
    default:
        break;
    }
    ASSERT(false);
}

//! Axis label for axis `i_axis` (0: horizontal, 1: vertical) in `units`.
//! MM and any unit without a dedicated label get the millimetre label; any other axis index
//! asserts.
std::string ImageCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    if (i_axis != 0 && i_axis != 1)
        ASSERT(false);
    const bool horizontal = (i_axis == 0);
    switch (units) {
    case Coords::NBINS:
        return horizontal ? "X [nbins]" : "Y [nbins]";
    case Coords::RADIANS:
        return horizontal ? "phi_f [rad]" : "alpha_f [rad]";
    case Coords::DEGREES:
        return horizontal ? "phi_f [deg]" : "alpha_f [deg]";
    case Coords::QSPACE:
        return horizontal ? "Qy [1/nm]" : "Qz [1/nm]";
    case Coords::MM:
    default:
        return horizontal ? "X [mm]" : "Y [mm]";
    }
}

//  ************************************************************************************************
//  class OffspecCoords
//  ************************************************************************************************

OffspecCoords::OffspecCoords(std::vector<const Scale*>&& axes)
    : CoordSystem2D(std::move(axes))
{
}

//! Member-wise copy; the base part goes through the CoordSystem2D copy constructor,
//! which rebuilds the axes from cloned_vector().
OffspecCoords::OffspecCoords(const OffspecCoords& other) = default;

//! Returns a heap-allocated copy made with the copy constructor.
OffspecCoords* OffspecCoords::clone() const
{
    auto* copy = new OffspecCoords(*this);
    return copy;
}

//! Converts the angle `value` (radians) into `units`. The result is the same for both axes.
//! Only RADIANS and DEGREES are supported; anything else asserts.
double OffspecCoords::calculateValue(size_t, Coords units, double value) const
{
    if (units == Coords::RADIANS)
        return value;
    if (units == Coords::DEGREES)
        return Units::rad2deg(value);
    ASSERT(false);
}

//! Axis label for axis `i_axis` (0: alpha_i, 1: alpha_f) in `units`.
//! Units other than NBINS and RADIANS get the degree label; any other axis index asserts.
std::string OffspecCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    if (i_axis > 1)
        ASSERT(false);
    const bool incident = (i_axis == 0);
    if (units == Coords::NBINS)
        return incident ? "X [nbins]" : "Y [nbins]";
    if (units == Coords::RADIANS)
        return incident ? "alpha_i [rad]" : "alpha_f [rad]";
    return incident ? "alpha_i [deg]" : "alpha_f [deg]";
}

//  ************************************************************************************************
//  class DepthprobeCoords
//  ************************************************************************************************

//! Label of DepthprobeCoords axis 1 (sample position) for every unit except NBINS.
const std::string z_axis_name = "Position [nm]";

DepthprobeCoords::DepthprobeCoords(std::vector<const Scale*>&& axes, double ki0)
    : CoordSystem2D(std::move(axes))
    , m_ki0(ki0)
{
}

//! Member-wise copy; the base part goes through the CoordSystem2D copy constructor,
//! which rebuilds the axes from cloned_vector().
DepthprobeCoords::DepthprobeCoords(const DepthprobeCoords& other) = default;

//! Defined out of line; performs only the default member-wise destruction.
DepthprobeCoords::~DepthprobeCoords() = default;

//! Returns a heap-allocated copy made with the copy constructor.
DepthprobeCoords* DepthprobeCoords::clone() const
{
    auto* copy = new DepthprobeCoords(*this);
    return copy;
}

std::vector<Coords> DepthprobeCoords::availableUnits() const
{
    auto result = CoordSystem2D::availableUnits();
    result.push_back(Coords::QSPACE);
    return result;
}

double DepthprobeCoords::calculateValue(size_t i_axis, Coords units, double value) const
{
    const auto& available_units = availableUnits();
    if (std::find(available_units.begin(), available_units.end(), units) == available_units.cend())
        ASSERT(false);

    if (i_axis == 1)
        return value; // unit conversions are not applied to sample position axis
    switch (units) {
    case Coords::RADIANS:
        return value;
    case Coords::QSPACE:
        return 2 * m_ki0 * std::sin(value);
    default:
    case Coords::DEGREES:
        return Units::rad2deg(value);
    }
}

std::string DepthprobeCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    if (i_axis == 0) {
        switch (units) {
        case Coords::NBINS:
            return "X [nbins]";
        case Coords::RADIANS:
            return "alpha_i [rad]";
        case Coords::QSPACE:
            return "Q [1/nm]";
        case Coords::DEGREES:
        default:
            return "alpha_i [deg]";
        }
    }
    if (i_axis == 1) {
        switch (units) {
        case Coords::NBINS:
            return "Y [nbins]";
        default:
            return "Position [nm]";
        }
    }
    ASSERT(false);
}
