//
// Copyright (C) 2011-13 Mark Wiebe, DyND Developers
// BSD 2-Clause License, see LICENSE.txt
//

#include <algorithm>
#include <sstream>
#include <stdexcept>

#include <dynd/type_promotion.hpp>
#include <dynd/types/string_type.hpp>

using namespace std;
using namespace dynd;

/*
static intptr_t min_strlen_for_builtin_kind(type_kind_t kind)
{
    switch (kind) {
        case bool_kind:
            return 1;
        case int_kind:
        case uint_kind:
            return 24;
        case real_kind:
            return 32;
        case complex_kind:
            return 64;
        default:
            throw runtime_error("cannot get minimum string length for specified kind");
    }
}
*/

// Returns `tp` when its data size is at least sizeof(int), otherwise the
// built-in `int` type. `tp` must be a built-in bool/int/uint type.
static ndt::type at_least_int_size(const ndt::type& tp)
{
    return (tp.get_data_size() >= sizeof(int)) ? tp : ndt::make_type<int>();
}

// Returns `tp` unless it is float16, in which case the built-in `float`
// type is returned. `tp` must be a built-in real type.
static ndt::type at_least_float32(const ndt::type& tp)
{
    return (tp.unchecked_get_builtin_type_id() != float16_type_id)
                    ? tp : ndt::make_type<float>();
}

// Promotes a pair of built-in int/uint types. When both are narrower than
// `int` the result is `int`; otherwise the type with the larger data size is
// returned. On equal sizes `tp0` is returned unless `tp1_wins_ties` is set.
static ndt::type promote_integer_pair(const ndt::type& tp0, const ndt::type& tp1,
                bool tp1_wins_ties)
{
    const size_t int_size = sizeof(int);
    const size_t size0 = tp0.get_data_size();
    const size_t size1 = tp1.get_data_size();

    if (size0 < int_size && size1 < int_size) {
        return ndt::make_type<int>();
    }
    if (tp1_wins_ties) {
        return (size0 > size1) ? tp0 : tp1;
    }
    return (size0 >= size1) ? tp0 : tp1;
}

/**
 * Computes the type two operands are promoted to for arithmetic.
 *
 * The promotion is done on the value types of `tp0` and `tp1`.
 *
 * When both value types are built-in:
 *  - bool/int/uint combinations are promoted to at least `int`; for a
 *    signed/unsigned pair of equal size the unsigned type is chosen.
 *  - Combining bool/int/uint with a real type gives the real type, raised
 *    to at least float32. This holds for either operand order.
 *  - real with real gives the larger type id of the two, at least float32.
 *  - Combining with a complex type gives the complex type, except that
 *    float64 with complex<float32> gives complex<float64>.
 *  - void combined with anything gives the other type.
 *
 * Otherwise:
 *  - string/fixedstring with string/fixedstring gives the default string
 *    type returned by ndt::make_string().
 *  - the `type` type combined with a string-kind type gives the `type` type.
 *  - void combined with anything gives the other type.
 *
 * \param tp0  The first operand type.
 * \param tp1  The second operand type.
 *
 * \returns  The promoted type.
 *
 * \throws std::runtime_error  if no promotion rule applies to the pair.
 */
ndt::type dynd::promote_types_arithmetic(const ndt::type& tp0, const ndt::type& tp1)
{
    // Use the value types
    const ndt::type& tp0_val = tp0.value_type();
    const ndt::type& tp1_val = tp1.value_type();

    if (tp0_val.is_builtin() && tp1_val.is_builtin()) {
        const type_kind_t kind0 = tp0_val.get_kind();
        const type_kind_t kind1 = tp1_val.get_kind();

        switch (kind0) {
            case bool_kind:
                switch (kind1) {
                    case bool_kind:
                        return ndt::make_type<int>();
                    case int_kind:
                    case uint_kind:
                        return at_least_int_size(tp1_val);
                    case void_kind:
                        return tp0_val;
                    case real_kind:
                        // The bool type doesn't affect float type sizes, except
                        // require at least float32
                        return at_least_float32(tp1_val);
                    default:
                        return tp1_val;
                }
            case int_kind:
            case uint_kind:
                switch (kind1) {
                    case bool_kind:
                        return at_least_int_size(tp0_val);
                    case int_kind:
                    case uint_kind:
                        // When the element sizes are equal, the uint kind wins;
                        // the only pairing where that means picking tp1 on a
                        // tie is signed tp0 with unsigned tp1.
                        return promote_integer_pair(tp0_val, tp1_val,
                                        kind0 == int_kind && kind1 == uint_kind);
                    case real_kind:
                        // Integer type sizes don't affect float type sizes, except
                        // require at least float32
                        return at_least_float32(tp1_val);
                    case complex_kind:
                        // Integer type sizes don't affect complex type sizes
                        return tp1_val;
                    case void_kind:
                        return tp0_val;
                    default:
                        break;
                }
                break;
            case real_kind:
                switch (kind1) {
                    case bool_kind:
                    case int_kind:
                    case uint_kind:
                        // Integer type sizes don't affect float type sizes, except
                        // require at least float32. This mirrors the
                        // (integer, real) branches so that the result does not
                        // depend on the operand order.
                        return at_least_float32(tp0_val);
                    case real_kind:
                        // Relies on the real type ids being ordered by size
                        return ndt::type(max(max(tp0_val.unchecked_get_builtin_type_id(),
                                        tp1_val.unchecked_get_builtin_type_id()), float32_type_id));
                    case complex_kind:
                        // float64 needs complex<float64> to hold its precision
                        if (tp0_val.get_type_id() == float64_type_id &&
                                        tp1_val.get_type_id() == complex_float32_type_id) {
                            return ndt::type(complex_float64_type_id);
                        } else {
                            return tp1_val;
                        }
                    case void_kind:
                        return tp0_val;
                    default:
                        break;
                }
                break;
            case complex_kind:
                switch (kind1) {
                    // Integer and float type sizes don't affect complex type sizes,
                    // except float64 which widens complex<float32>
                    case bool_kind:
                    case int_kind:
                    case uint_kind:
                    case real_kind:
                        if (tp0_val.unchecked_get_builtin_type_id() == complex_float32_type_id &&
                                        tp1_val.unchecked_get_builtin_type_id() == float64_type_id) {
                            return ndt::type(complex_float64_type_id);
                        } else {
                            return tp0_val;
                        }
                    case complex_kind:
                        return (tp0_val.get_data_size() >= tp1_val.get_data_size()) ? tp0_val
                                                                          : tp1_val;
                    case void_kind:
                        return tp0_val;
                    default:
                        break;
                }
                break;
            case void_kind:
                return tp1_val;
            default:
                break;
        }

        // Every supported built-in pairing returns above
        stringstream ss;
        ss << "internal error in built-in dynd type promotion of " << tp0_val << " and " << tp1_val;
        throw std::runtime_error(ss.str());
    }

    // HACK for getting simple string type promotions.
    // TODO: Do this properly in a pluggable manner.
    const bool tp0_is_string = tp0_val.get_type_id() == string_type_id ||
                    tp0_val.get_type_id() == fixedstring_type_id;
    const bool tp1_is_string = tp1_val.get_type_id() == string_type_id ||
                    tp1_val.get_type_id() == fixedstring_type_id;
    if (tp0_is_string && tp1_is_string) {
        // Always promote to the default utf-8 string (for now, maybe return encoding, etc later?)
        return ndt::make_string();
    }

    // type, string -> type
    if (tp0_val.get_type_id() == type_type_id && tp1_val.get_kind() == string_kind) {
        return tp0_val;
    }
    // string, type -> type
    if (tp0_val.get_kind() == string_kind && tp1_val.get_type_id() == type_type_id) {
        return tp1_val;
    }

    // In general, if one type is void, just return the other type
    if (tp0_val.get_type_id() == void_type_id) {
        return tp1_val;
    } else if (tp1_val.get_type_id() == void_type_id) {
        return tp0_val;
    }

    stringstream ss;
    ss << "type promotion of " << tp0 << " and " << tp1 << " is not yet supported";
    throw std::runtime_error(ss.str());
}
