% ISO Prolog magic square finder (normal magic squares).
%
%   magic_square(N, Ss)
%
%   - N is the grid dimension, an integer N >= 1 (for any other N the search
%     finds nothing and Ss = []).
%   - Ss is a list of all N×N magic squares using the consecutive integers
%     1..N^2 exactly once.
%   - Each square is represented as a list of N rows, each a list of N integers.
%
% This file is intentionally self-contained (no library dependencies).

% magic_square(+N, -Ss)
%   Ss is the list of every normal magic square of order N, in the order in
%   which magic_square_one/2 enumerates them.
magic_square(N, Ss) :-
    findall(Square,
            magic_square_one(N, Square),
            Ss).

% magic_square_len(+N, -Count)
%   Convenience predicate for benchmarks: Count is the number of normal magic
%   squares of order N (the length of the list magic_square/2 would return).
%   Only a placeholder atom is collected per solution, so the squares
%   themselves are never copied into a list.
magic_square_len(N, Count) :-
    findall(hit, magic_square_one(N, _Square), Hits),
    length(Hits, Count).

% magic_square_one(+N, -Square)
%   Nondeterministic generator: each solution binds Square to one magic
%   square of order N (a list of N rows of N integers drawn from 1..N^2).
%   Fails immediately unless N is a positive integer.
magic_square_one(N, Square) :-
    integer(N),
    N > 0,
    Largest is N*N,
    range(1, Largest, Pool),
    magic_sum(N, Target),
    zeros(N, EmptyColumnSums),
    % Start at row 1 with both diagonal sums at 0.
    fill_rows(1, N, Target, Pool, EmptyColumnSums, 0, 0, Square).

% magic_sum(+N, -Sum)
%   Sum is the magic constant N*(N^2+1)/2 of a normal magic square of order N.
%   N*(N^2+1) is always even, so the integer division is exact.
magic_sum(N, Sum) :-
    Sum is N * (N*N + 1) // 2.

% fill_rows(+R, +N, +Sum, +Numbers0, +ColSums0, +D10, +D20, -Rows)
%   Fills rows R..N of the square and returns them in Rows.
%   Numbers0 - values not yet placed, in ascending order (range/3 builds the
%              list sorted and select1/3 keeps the remaining order).
%   ColSums0 - running sum of each column over rows 1..R-1.
%   D10, D20 - running sums of the main diagonal and the anti-diagonal.
%   Rows 1..N-1 are searched cell by cell by fill_row/14; the last row is
%   fully determined by the column sums and built by fill_last_row/12.
fill_rows(R, N, Sum, Numbers0, ColSums0, D10, D20, [Row|Rows]) :-
    R < N,
    fill_row(1, R, N, Sum, 0, Numbers0, Numbers1, ColSums0, ColSums1, D10, D11, D20, D21, Row),
    R1 is R + 1,
    fill_rows(R1, N, Sum, Numbers1, ColSums1, D11, D21, Rows).
fill_rows(R, N, Sum, Numbers0, ColSums0, D10, D20, [Row]) :-
    R =:= N,
    % Each last-row cell is forced to complete its own column.
    fill_last_row(1, N, Sum, Numbers0, Numbers1, ColSums0, ColSums1, D10, D11, D20, D21, Row),
    % Every value 1..N^2 must have been used. Since rows 1..N-1 each sum to
    % Sum and all values total N*Sum, the last row then also sums to Sum.
    Numbers1 = [],
    all_eq(ColSums1, Sum),
    % Both diagonals must reach the magic constant exactly.
    D11 =:= Sum,
    D21 =:= Sum.

% fill_row(+C, +R, +N, +Sum, +RowSum0, +Numbers0, -Numbers1,
%          +ColSums0, -ColSums1, +D10, -D11, +D20, -D21, -Cells)
%   Fills cells (R,C)..(R,N) of a non-final row R (fill_rows/8 calls it
%   only with R < N) and returns their values in Cells.
%   RowSum0             - sum of the cells (R,1)..(R,C-1) already placed.
%   Numbers0/Numbers1   - unused values before/after this call.
%   ColSums0/ColSums1   - column running sums before/after this call.
%   D10/D11, D20/D21    - main/anti-diagonal running sums before/after.
%   On backtracking it enumerates every admissible way to finish the row.
fill_row(C, R, N, Sum, RowSum0, Numbers0, Numbers1, ColSums0, ColSums1, D10, D11, D20, D21, [V|Vs]) :-
    C =< N,
    ( odd_center_cell(N, R, C, CenterV) ->
        % This cell's value is fixed in advance by odd_center_cell/4.
        ( C =:= N ->
            % The center is in the last column only when N = 1, and
            % fill_rows/8 never calls fill_row/14 for N = 1. If this branch
            % ran, the row sum forcing V would have to agree with the
            % fixed value.
            V is Sum - RowSum0,
            V =:= CenterV,
            select1(V, Numbers0, NumbersA),
            RowSum1 is Sum
        ;
            V = CenterV,
            select1(V, Numbers0, NumbersA),
            RowSum1 is RowSum0 + V,
            % Later cells hold values >= 1, so the partial sum must stay below Sum.
            RowSum1 < Sum
        )
    ;
        ( C =:= N ->
            % Last column: the row sum determines V, which must still be unused.
            V is Sum - RowSum0,
            select1(V, Numbers0, NumbersA),
            RowSum1 is Sum
        ;
            % Free cell: select1/3 tries each unused value in list order.
            select1(V, Numbers0, NumbersA),
            RowSum1 is RowSum0 + V,
            RowSum1 < Sum
        )
    ),
    % Necessary-condition pruning: the row, this column and any diagonal
    % through (R,C) must still be completable from the remaining values.
    row_bounds_ok(C, N, Sum, RowSum1, NumbersA),
    col_add(C, V, R, N, Sum, ColSums0, ColSumsA),
    col_bounds_ok(C, R, N, Sum, NumbersA, ColSumsA),
    diag_add(R, C, V, N, Sum, D10, D11a, D20, D21a),
    diag_bounds_ok(R, C, N, Sum, NumbersA, D11a, D21a),
    C1 is C + 1,
    fill_row(C1, R, N, Sum, RowSum1, NumbersA, Numbers1, ColSumsA, ColSums1, D11a, D11, D21a, D21, Vs).
% Past the last column: the row is complete; pass the accumulators through.
fill_row(C, _R, N, _Sum, _RowSum, Numbers, Numbers, ColSums, ColSums, D1, D1, D2, D2, []) :-
    C is N + 1.

% fill_last_row(+C, +N, +Sum, +Numbers0, -Numbers1, +ColSums0, -ColSums1,
%               +D10, -D11, +D20, -D21, -Cells)
%   Fills cells (N,C)..(N,N) of the final row. No search happens here: each
%   cell is forced to Sum minus its column's running sum, and that value
%   must still be unused.
%   ColSums0 holds the running sums of columns C..N only, because the list
%   is consumed one column per cell. Each completed column is recorded as
%   Sum in ColSums1.
%   The diagonal sums are updated through diag_add/9 with R = N.
fill_last_row(C, N, Sum, Numbers0, Numbers1, [ColSum0|ColSums0], [Sum|ColSums1], D10, D11, D20, D21, [V|Vs]) :-
    C =< N,
    % The only value that makes column C sum to Sum.
    V is Sum - ColSum0,
    % It must still be available; this also rejects V =< 0, since the pool
    % holds only values from 1..N^2.
    select1(V, Numbers0, NumbersA),
    diag_add(N, C, V, N, Sum, D10, D11a, D20, D21a),
    C1 is C + 1,
    fill_last_row(C1, N, Sum, NumbersA, Numbers1, ColSums0, ColSums1, D11a, D11, D21a, D21, Vs).
% Past the last column: pass the remaining accumulators through.
fill_last_row(C, N, _Sum, Numbers, Numbers, [], [], D1, D1, D2, D2, []) :-
    C is N + 1.

% col_add(+C, +V, +R, +N, +Sum, +ColSums0, -ColSums)
%   Adds V to the C-th (1-based) entry of ColSums0; all other entries are
%   copied unchanged. Fails if that column would exceed Sum, or, when R is
%   the last row (R = N), if it would not equal Sum exactly.
col_add(C, V, R, N, Sum, [Current|Rest0], Result) :-
    (   C == 1
    ->  Updated is Current + V,
        Updated =< Sum,
        (   R =:= N
        ->  Updated =:= Sum
        ;   true
        ),
        Result = [Updated|Rest0]
    ;   C > 1,
        PrevC is C - 1,
        Result = [Current|Rest],
        col_add(PrevC, V, R, N, Sum, Rest0, Rest)
    ).

% diag_add(+R, +C, +V, +N, +Sum, +D10, -D11, +D20, -D21)
%   Accounts for value V placed at cell (R,C).
%   - The main-diagonal sum D10 -> D11 grows by V when R = C.
%   - The anti-diagonal sum D20 -> D21 grows by V when C = N+1-R.
%   Fails if either resulting sum exceeds Sum.
diag_add(R, C, V, N, Sum, D10, D11, D20, D21) :-
    (   R =:= C
    ->  MainSum is D10 + V
    ;   MainSum is D10
    ),
    AntiCol is N + 1 - R,
    (   C =:= AntiCol
    ->  AntiSum is D20 + V
    ;   AntiSum is D20
    ),
    MainSum =< Sum,
    AntiSum =< Sum,
    D11 = MainSum,
    D21 = AntiSum.

% --- pruning helpers (ISO arithmetic only) ---

% row_bounds_ok(+C, +N, +Sum, +RowSum, +Numbers)
%   Column C of the current row has just been filled and the row's partial
%   sum is RowSum. Checks that the N-C cells still empty in the row could
%   plausibly make up the missing Sum - RowSum using values from Numbers
%   (a min/max range test via bounds_possible/3).
row_bounds_ok(C, N, Sum, RowSum, Numbers) :-
    Missing is Sum - RowSum,
    EmptyCells is N - C,
    bounds_possible(Numbers, EmptyCells, Missing).

% col_bounds_ok(+C, +R, +N, +Sum, +Numbers, +ColSums)
%   Row R has just been filled in column C. If rows remain below R, checks
%   that the N-R cells still empty in column C could plausibly make up
%   Sum minus the column's running sum, using values from Numbers.
%   Succeeds trivially when R is the last row.
col_bounds_ok(C, R, N, Sum, Numbers, ColSums) :-
    RowsBelow is N - R,
    (   RowsBelow > 0
    ->  ms_nth1(C, ColSums, Filled),
        Needed is Sum - Filled,
        bounds_possible(Numbers, RowsBelow, Needed)
    ;   true
    ).

% diag_bounds_ok(+R, +C, +N, +Sum, +Numbers, +D1, +D2)
%   Cell (R,C) has just been filled. For each diagonal passing through it
%   (main: R = C; anti: C = N+1-R), checks that the N-R diagonal cells in
%   the rows below could plausibly make up Sum minus that diagonal's running
%   sum (D1 or D2), using values from Numbers.
diag_bounds_ok(R, C, N, Sum, Numbers, D1, D2) :-
    (   R =:= C
    ->  MainCellsLeft is N - R,
        MainNeed is Sum - D1,
        bounds_possible(Numbers, MainCellsLeft, MainNeed)
    ;   true
    ),
    AntiCol is N + 1 - R,
    (   C =:= AntiCol
    ->  AntiCellsLeft is N - R,
        AntiNeed is Sum - D2,
        bounds_possible(Numbers, AntiCellsLeft, AntiNeed)
    ;   true
    ).

% bounds_possible(+Numbers, +K, +Target)
%   Necessary condition for choosing K values from Numbers that sum to
%   Target. Numbers is assumed to be in ascending order (as maintained by the
%   search). Target must lie between the sum of the K smallest and the sum of
%   the K largest values. With K =< 0 nothing can be added, so Target must
%   be 0. Fails if Numbers has fewer than K elements.
bounds_possible(Numbers, K, Target) :-
    (   K =< 0
    ->  Target =:= 0
    ;   sum_first_k(Numbers, K, Smallest),
        sum_last_k(Numbers, K, Largest),
        Smallest =< Target,
        Target =< Largest
    ).

% sum_first_k(+Numbers, +K, -Sum)
%   Sum is the sum of the first K elements of Numbers (0 when K =< 0).
%   Fails if Numbers has fewer than K elements.
sum_first_k(Numbers, K, Sum) :-
    sum_first_k_acc_(Numbers, K, 0, Sum).

% sum_first_k_acc_(+Xs, +K, +Acc, -Sum): accumulator worker for sum_first_k/3.
sum_first_k_acc_(Xs, K, Acc, Sum) :-
    (   K =< 0
    ->  Sum = Acc
    ;   Xs = [X|Rest],
        Acc1 is Acc + X,
        K1 is K - 1,
        sum_first_k_acc_(Rest, K1, Acc1, Sum)
    ).

% sum_last_k(+Numbers, +K, -Sum)
%   Sum is the sum of the last K elements of Numbers: skip the leading
%   Len-K elements, then add up what remains.
sum_last_k(Numbers, K, Sum) :-
    length_(Numbers, Count),
    Leading is Count - K,
    drop_(Leading, Numbers, Suffix),
    sum_first_k(Suffix, K, Sum).

% drop_(+N, +Xs, -Ys)
%   Ys is Xs without its first N elements; Ys = Xs when N =< 0.
%   Fails if Xs has fewer than N elements.
drop_(N, Xs, Ys) :-
    (   N =< 0
    ->  Ys = Xs
    ;   Xs = [_|Tail],
        Remaining is N - 1,
        drop_(Remaining, Tail, Ys)
    ).

% length_(+List, -Len): Len is the number of elements of List (self-contained
% replacement for length/2).
length_(List, Len) :-
    length_acc_(List, 0, Len).

% length_acc_(+List, +Seen, -Len): tail-recursive counter; Seen counts the
% elements already walked past.
length_acc_([], Len, Len).
length_acc_([_|Tail], Seen, Len) :-
    Seen1 is Seen + 1,
    length_acc_(Tail, Seen1, Len).

% ms_nth1(+Pos, +List, -Elem)
%   Elem is the Pos-th element of List (1-based). Fails if Pos < 1 or
%   Pos exceeds the list length.
ms_nth1(1, [Found|_], Found).
ms_nth1(Pos, [_|Rest], Found) :-
    Pos > 1,
    PrevPos is Pos - 1,
    ms_nth1(PrevPos, Rest, Found).

% odd_center_cell(+N, +R, +C, -CenterV)
%   Succeeds when (R,C) is the center cell of a square of order N whose
%   center value is forced in every normal magic square, binding CenterV to
%   (N^2+1)/2.
%
%   Only orders 1 and 3 qualify. In a 3x3 square, the middle row, middle
%   column and both diagonals (4 lines, total 4*15 = 60) cover every cell
%   once and the center three extra times. So 60 = 45 + 3*Center, which
%   gives Center = 5.
%   For odd N >= 5 the center is NOT fixed: 5x5 normal magic squares with
%   other center values exist. Pinning it there would silently drop valid
%   solutions, so this pruning is deliberately limited to N =< 3.
odd_center_cell(N, R, C, CenterV) :-
    1 is N mod 2,
    N =< 3,
    Mid is (N + 1) // 2,
    R =:= Mid,
    C =:= Mid,
    N2 is N*N,
    CenterV is (N2 + 1) // 2.

% --- Minimal ISO helpers (self-contained): range/3, select1/3, zeros/2, all_eq/2 ---

% range(+Lo, +Hi, -Xs): Xs is the ascending list Lo, Lo+1, ..., Hi
% (empty when Lo > Hi).
range(Lo, Hi, Xs) :-
    (   Lo > Hi
    ->  Xs = []
    ;   Xs = [Lo|Rest],
        Next is Lo + 1,
        range(Next, Hi, Rest)
    ).

% select1(?Elem, +List, -Rest)
%   Rest is List with one occurrence of Elem removed. On backtracking it
%   removes each position in turn, from front to back, keeping the relative
%   order of the remaining elements.
select1(Elem, [Head|Tail], Rest) :-
    (   Elem = Head,
        Rest = Tail
    ;   Rest = [Head|Rest1],
        select1(Elem, Tail, Rest1)
    ).

% zeros(+N, -Zs): Zs is a list of N zeros (empty when N =< 0).
zeros(N, Zs) :-
    zeros_(N, Zs).

% zeros_(+N, -Zs): worker that builds the list one zero at a time.
zeros_(N, Zs) :-
    (   N =< 0
    ->  Zs = []
    ;   Zs = [0|Rest],
        Left is N - 1,
        zeros_(Left, Rest)
    ).

% all_eq(+List, +Expected): every element of List is arithmetically equal
% (=:=) to Expected. Trivially true for the empty list.
all_eq([], _Expected).
all_eq([Value|Values], Expected) :-
    Value =:= Expected,
    all_eq(Values, Expected).

