:- use_module(library(between)).
:- use_module('./apl.pl').

% Portable regression suite for apl.pl.
%
% Run with:
%   scryer-prolog -f apl_tests.pl -g run_tests -g halt
%   tpl apl_tests.pl -g run_tests

% run_tests/0
%   Entry point of the suite. Runs every test below in listed order via
%   run_tests/1 and prints "apl tests passed" once all of them succeed.
%   A failing or erroring test aborts the run with the exception raised by
%   run_test/2.
%   Each listed name N is run as the goal test_N (see name_goal_pairs/2),
%   so a test is registered by adding its name here.
run_tests :-
    Names = [
        vector_expressions,
        input_type_error,
        arithmetic_precedence,
        left_associative_arithmetic,
        reduce_first_axis,
        relation_precedence,
        pick_index_precedence,
        vector_bare_refs,
        vector_grouped_application,
        env_bound_verb,
        export_tap,
        simplified_export_mask,
        between_integration,
        bundled_range,
        bundled_col,
        bundled_transpose,
        bundled_where,
        shorthand_div,
        shorthand_mod,
        equality_test_mask,
        zero_test_mask,
        mask_prefix,
        dif_returns_left_output,
        dif_returns_right_output,
        relational_ground
    ],
    name_goal_pairs(Names, Suite),
    run_tests(Suite),
    write('apl tests passed'),
    nl.

% name_goal_pairs(+Names, -Pairs)
%   Pairs each test name N with its goal atom test_N, preserving order.
name_goal_pairs([], []).
name_goal_pairs([Name|Names], [Name-Goal|Pairs]) :-
    atom_concat(test_, Name, Goal),
    name_goal_pairs(Names, Pairs).

% run_tests(+Suite)
%   Runs every Name-Goal pair of Suite through run_test/2, left to right.
%   Stops at the first test that raises.
run_tests([]).
run_tests([Entry|Remaining]) :-
    Entry = Name-Goal,
    run_test(Name, Goal),
    run_tests(Remaining).

% run_test(+Name, :Goal)
%   Runs Goal once, committing to its first solution, and prints "ok Name"
%   on success.
%   Throws:
%     error(test_failed(Name, failed), _) if Goal fails;
%     error(test_failed(Name, Error), _)  if Goal raises Error.
%   The failure check sits outside catch/3. Otherwise the test_failed
%   exception for a failing goal would be caught by the same handler and
%   wrapped a second time.
run_test(Name, Goal) :-
    (   catch(Goal, Error, throw(error(test_failed(Name, Error), _)))
    ->  true
    ;   throw(error(test_failed(Name, failed), _))
    ),
    write('ok '),
    write(Name),
    nl.

% must_equal(+Expected, +Actual)
%   Succeeds when Expected and Actual are structurally identical (==/2, so
%   no variables are bound). Otherwise throws
%   error(assertion_failed(expected(Expected), actual(Actual)), _).
must_equal(Want, Got) :-
    (   Want \== Got
    ->  throw(error(assertion_failed(expected(Want), actual(Got)), _))
    ;   true
    ).

% must_error(+Expected, :Goal)
%   Runs Goal once and asserts that it raises an exception unifying with
%   Expected; the comparison is delegated to must_match_error/2.
%   Throws
%     error(assertion_failed(expected_error(Expected), actual(success)), _)
%   if Goal succeeds, or the same term with actual(failure) if Goal fails.
%   The outcome is recorded inside catch/3 and only asserted on afterwards.
%   This keeps our own "unexpected success" assertion from being caught and
%   matched against Expected: a general Expected such as error(_, _) would
%   otherwise accept a goal that never raised.
must_error(Expected, Goal) :-
    catch(
        (   call(Goal) ->
            Outcome = success
        ;   Outcome = failure
        ),
        Error,
        Outcome = raised(Error)
    ),
    (   Outcome = raised(Actual) ->
        must_match_error(Expected, Actual)
    ;   throw(error(assertion_failed(expected_error(Expected), actual(Outcome)), _))
    ).

% must_match_error(?Expected, +Actual)
%   Succeeds, leaving the unifier in place, when Expected unifies with the
%   caught exception term Actual. Otherwise throws
%   error(assertion_failed(expected_error(Expected), actual_error(Actual)), _).
must_match_error(Same, Same) :-
    !.
must_match_error(Expected, Actual) :-
    throw(error(assertion_failed(expected_error(Expected), actual_error(Actual)), _)).

% Brace vector literals evaluate each element expression:
% {1 + 2 3 + 4} must yield [3,7].
test_vector_expressions :-
    apl('{1 + 2 3 + 4}', Value),
    must_equal([3,7], Value).

% apl/2 must reject a list input mixing an atom and an integer with a
% type_error(apl_input, _) exception.
test_input_type_error :-
    BadInput = ['1', 2],
    must_error(error(type_error(apl_input, _), _), apl(BadInput, _)).

% Multiplication binds tighter than addition: 1 + 2 * 3 = 7.
test_arithmetic_precedence :-
    apl('1 + 2 * 3', Value),
    must_equal(7, Value).

% Subtraction groups to the left: (10 - 3) - 1 = 6.
test_left_associative_arithmetic :-
    apl('10 - 3 - 1', Value),
    must_equal(6, Value).

% +/ over a {3 3} reshape of `i 10` must give [9,12,15]. That value is
% consistent with summing down the first axis; confirm against apl.pl.
test_reduce_first_axis :-
    apl('+/ {3 3} # i 10', Value),
    must_equal([9,12,15], Value).

% `:` binds looser than `+`, so x is bound to 3 + 4. Both the result and
% the exported x must be 7.
test_relation_precedence :-
    apl('?x : 3 + 4', [x=Bound], Value),
    must_equal(7, Value),
    must_equal(7, Bound).

% The whole expression right of `@` must be evaluated as the index before
% picking from ?pairs, yielding [1,2,3].
test_pick_index_precedence :-
    Pairs = [[1,9],[2,8],[3,7]],
    apl('?pairs @ 2 * i t ?pairs', [pairs=Pairs], Value),
    must_equal([1,2,3], Value).

% Bare ?name references inside a brace vector resolve from the environment
% list: {?x ?y} with x=7, y=9 gives [7,9].
test_vector_bare_refs :-
    Env = [x=7, y=9],
    apl('{?x ?y}', Env, Value),
    must_equal([7,9], Value).

% A parenthesised application inside a brace vector is a single element:
% applying the verb bound to f ('$') to the 2x2 list x yields [[2,2]].
test_vector_grouped_application :-
    Matrix = [[1,2],[3,4]],
    apl('{(?f ?x)}', [f='$', x=Matrix], Value),
    must_equal([[2,2]], Value).

% An environment variable may hold a verb: ?f ?x with f='$' applied to the
% 2x2 list x yields [2,2].
test_env_bound_verb :-
    Matrix = [[1,2],[3,4]],
    apl('?f ?x', [f='$', x=Matrix], Value),
    must_equal([2,2], Value).

% The expression has one solution per value produced by `1 range 4`. For
% each one, x (exported with ^x) must be bound to the vector that +/
% summed, and the sum must match.
test_export_tap :-
    findall(Sum-Tapped,
            apl('+/ ^x i 1 range 4', [x=Tapped], Sum),
            Solutions),
    must_equal([0-[0], 1-[0,1], 3-[0,1,2], 6-[0,1,2,3]], Solutions).

% ^mask exports the divisibility mask so the later ?mask can reuse it.
% Both the filtered result and the exported mask are checked.
test_simplified_export_mask :-
    apl('^mask (1 + i 10) % 3 e 0 & ?mask m (1 + i 10)',
        [mask=Exported],
        Value),
    must_equal([3,6,9], Value),
    must_equal([0,0,1,0,0,1,0,0,1,0], Exported).

% `between` used as an infix verb (presumably backed by between/3)
% enumerates 0..3 on backtracking. All solutions are collected.
test_between_integration :-
    findall(N, apl('0 between 3', N), Ns),
    must_equal([0,1,2,3], Ns).

% `range` enumerates on backtracking: 1 range 4 yields 1, 2, 3, 4.
% NOTE(review): unlike the other tests this one calls apl:apl/2 with an
% explicit module qualifier. That may be deliberate (context-module lookup
% of bundled verbs), so confirm before unqualifying it.
test_bundled_range :-
    findall(N, apl:apl('1 range 4', N), Ns),
    must_equal([1,2,3,4], Ns).

% `col` selects column 1 of a {3 2} reshape of `i 10`.
test_bundled_col :-
    apl('1 col ({3 2} # i 10)', Column),
    must_equal([1,3,5], Column).

% `transpose` swaps the axes of a {2 3} reshape of `i 10`.
test_bundled_transpose :-
    apl('transpose ({2 3} # i 10)', Swapped),
    must_equal([[0,3],[1,4],[2,5]], Swapped).

% `where` returns the positions of the 1s in a mask (0-based, per the
% expected value).
test_bundled_where :-
    apl('where ((1 + i 10) % 3 e 0)', Indices),
    must_equal([2,5,8], Indices).

% `d` is integer division: 7 d 2 = 3.
test_shorthand_div :-
    apl('7 d 2', Quotient),
    must_equal(3, Quotient).

% `%` is the remainder: 7 % 3 = 1.
test_shorthand_mod :-
    apl('7 % 3', Remainder),
    must_equal(1, Remainder).

% `e` compares elementwise and yields a 0/1 mask.
test_equality_test_mask :-
    apl('((i 10) % 3) e 0', Mask),
    must_equal([1,0,0,1,0,0,1,0,0,1], Mask).

% `z` is the zero test; it must produce the same mask as `... e 0`.
test_zero_test_mask :-
    apl('z ((i 10) % 3)', Mask),
    must_equal([1,0,0,1,0,0,1,0,0,1], Mask).

% `m` keeps the rows whose mask bit is 1.
test_mask_prefix :-
    apl('{1 0 1} m ({3 2} # i 10)', Kept),
    must_equal([[0,1],[4,5]], Kept).

% `!` appears to be a dif-style disequality (per the test names). The
% result must be the left operand, so binding x afterwards to a permitted
% value (5) must flow through to the result.
test_dif_returns_left_output :-
    apl('?x ! 3', [x=Left], Value),
    Left = 5,
    must_equal(5, Value).

% Same as above with the unbound operand on the right: the result must
% follow y once it is bound to 7.
test_dif_returns_right_output :-
    apl('5 ! ?y', [y=Right], Value),
    Right = 7,
    must_equal(7, Value).

% A relational verb on ground operands yields a plain value.
% NOTE(review): this expects `1 ge 0` to evaluate to 0. If `ge` means >=,
% that looks inverted; confirm against apl.pl.
test_relational_ground :-
    apl('1 ge 0', Value),
    must_equal(0, Value).
