classdef ArithmeticParser
    % ArithmeticParser Class for parsing strings representing arithmetic
    % operations on numbers and various system properties.
    methods(Static)
        function ast = parse(expr)
            % This is the main function of the class.
            % inputs:
            % expr : a string (or char vector) with numbers, arithmetic
            %        operations, or names of properties. this text is
            %        tokenized for the creation of the AST, so this function
            %        also accepts a cell array of tokens. each token MUST be
            %        a struct with the fields 'type' and 'value', with values
            %        like the ones the function `tokenize` puts on the tokens.
            %
            % outputs:
            % ast :  an AST (Abstract Syntax Tree) representation of
            %        `expr`. this is a binary tree, implemented with
            %        `struct` objects, that gives the correct order of
            %        operations.
            %
            % errors:
            %       - "Invalid expression type" when `expr` is neither text
            %       nor a cell array.
            %       - "invalid expression" when the token list is empty,
            %       fails validation, or is not fully consumed by the parser.
            %
            % for example, the expr: "9 + (3^(4-1)) * 5 + P_m" will look
            % something like this:
            %             +
            %           /   \
            %          +    P_m
            %         / \
            %        9   *
            %           / \
            %          ^   5
            %         / \
            %        3   -
            %           / \
            %          4   1
            % limitations:
            %       - property names MUST NOT start with digits
            %       - there MUST be an operation, closing parenthesis or
            %       the end of the line after a property name. the parsing
            %       of property names only stops at operations.
            %       - parsing of property names is pretty delicate.
            %       problems can be avoided by creating good tokens, and
            %       passing the token cell array straight to this function.
            %
            if isstring(expr)
                expr = char(expr);  % Convert string to char array
            end
            if ischar(expr)
                tokens = ArithmeticParser.tokenize(expr);
            elseif iscell(expr)
                tokens = expr;
            else
                error("Invalid expression type");
            end

            % An empty token list can never describe an expression; reject
            % it here with the same message used for every invalid input.
            if isempty(tokens)
                error("invalid expression");
            end

            if not(ArithmeticParser.validate(tokens))
                error("invalid expression");
            end
            [ast, pos] = ArithmeticParser.parseAddition(tokens, 1);

            % Defensive check: the parser must have consumed every token,
            % otherwise part of the expression would be silently ignored.
            if pos <= numel(tokens)
                error("invalid expression");
            end
        end

        function printTree(ast, depth, prefix)
            % This function prints the AST that parse() creates.
            % inputs:
            % ast    : a node struct with the fields 'type' and 'value'
            %          (and 'left'/'right' for 'OPERATION' and 'UNARY' nodes).
            % depth  : (optional, default 0) nesting level of the node; each
            %          line is indented by depth + 1 two-space units.
            % prefix : (optional, default "") label printed before the node,
            %          "L:" / "R:" for the children printed recursively.
            %
            % Each optional argument gets its default independently, so
            % calling printTree(ast, depth) without a prefix also works.
            if nargin < 3
                prefix = "";
            end
            if nargin < 2
                depth = 0;
            end

            for i=0:depth
                fprintf('  ');
            end
            fprintf("%s %s: %s\n", prefix, string(ast.type), string(ast.value));
            if strcmp(ast.type, 'OPERATION') || strcmp(ast.type, 'UNARY')
                ArithmeticParser.printTree(ast.left, depth + 1, "L:");
                % unary nodes store a non-struct placeholder in 'right'
                if isstruct(ast.right)
                    ArithmeticParser.printTree(ast.right, depth + 1,"R:");
                end
            end
        end

        function f = compileAST(ast)
            %compileAST turns the AST into a function handle
            % The returned handle has the signature
            %   f(system, component_manager, H_before)
            % and evaluates the expression described by `ast`:
            %   - 'PROPERTY' nodes defer to ArithmeticParser.parse_property
            %   - 'NUMBER' nodes return their (numeric) value
            %   - every other node applies the operator named by ast.value
            %     to its compiled child / children.

            % --- leaves -------------------------------------------------
            if strcmp(ast.type, 'PROPERTY')
                f = @(system, component_manager, H_before) ArithmeticParser.parse_property(ast.value, system, component_manager, H_before);
                return;
            end
            if strcmp(ast.type, 'NUMBER')
                constant = ast.value;
                if isstring(constant) || ischar(constant)
                    constant = str2double(constant);
                end
                f = @(system, component_manager, H_before) constant;
                return;
            end

            % --- inner nodes --------------------------------------------
            leftFn = ArithmeticParser.compileAST(ast.left);
            if isstruct(ast.right)
                rightFn = ArithmeticParser.compileAST(ast.right);
            end

            % Pick the operator as a function handle; the named functions
            % are the ones MATLAB calls for the operators +, -, *, /, ^
            % and unary minus.
            switch ast.value
                case 'PLUS'
                    op = @plus;     takesTwo = true;
                case 'MINUS'
                    op = @minus;    takesTwo = true;
                case 'MUL'
                    op = @mtimes;   takesTwo = true;
                case 'DIV'
                    op = @mrdivide; takesTwo = true;
                case 'FACTOR'
                    op = @mpower;   takesTwo = true;
                case 'UMINUS'
                    op = @uminus;   takesTwo = false;
                case 'Real'
                    op = @real;     takesTwo = false;
                case 'Imag'
                    op = @imag;     takesTwo = false;
                case 'Conj'
                    op = @conj;     takesTwo = false;
                case 'Abs'
                    op = @abs;      takesTwo = false;
                case 'sin'
                    op = @sin;      takesTwo = false;
                case 'cos'
                    op = @cos;      takesTwo = false;
                otherwise
                    error("unknown token value: " + ast.value);
            end

            if takesTwo
                f = @(system, component_manager, H_before) op( ...
                    leftFn(system, component_manager, H_before), ...
                    rightFn(system, component_manager, H_before));
            else
                f = @(system, component_manager, H_before) op( ...
                    leftFn(system, component_manager, H_before));
            end
        end
    end

    methods (Static, Access = public)
        function [isValid] = areParanthesisBalanced(tokens)
            %% areParanthesisBalanced - check that parentheses pair up
            % Walks the token cell array once, counting 'LPAREN' tokens up
            % and 'RPAREN' tokens down. The result is false as soon as a
            % closing parenthesis appears without a matching opening one,
            % or when some opening parenthesis is never closed.
            open_count = 0;
            for k = 1:numel(tokens)
                kind = tokens{k}.type;
                if strcmp(kind, 'LPAREN')
                    open_count = open_count + 1;
                elseif strcmp(kind, 'RPAREN')
                    open_count = open_count - 1;
                    % the count can only go negative right after a decrement
                    if open_count < 0
                        isValid = false;
                        return;
                    end
                end
            end
            isValid = (open_count == 0);
        end
        function [isValid] = areTokensOrdered(tokens)
            %% areTokensOrdered - make sure token types play nicely
            % after each type of token, only a certain few other types can
            % follow. for instance, there can't be a "+" binary operation
            % right after another binary operation, or a left parenthesis.
            %
            % inputs:
            % tokens : cell array of token structs with a 'type' field.
            %
            % outputs:
            % isValid : true when the sequence of token types can form an
            %           expression. an empty token list is not valid.
            %
            % errors:
            %       "Unknown token type: ..." when a token (other than the
            %       last one) has a type this function does not know.

            % nothing to order: an empty list is not an expression
            if isempty(tokens)
                isValid = false;
                return;
            end
            if isscalar(tokens)
                isValid = strcmp(tokens{1}.type, 'NUMBER') || ...
                    strcmp(tokens{1}.type, 'PROPERTY');
                return;
            end
            % an expression can't start with a binary operation or a right
            % parenthesis (the pairwise checks below only look at what
            % FOLLOWS a token, so the first token needs its own check)
            first = tokens{1};
            if strcmp(first.type, 'BINARY-OP') || ...
                    strcmp(first.type, 'RPAREN')
                isValid = false;
                return;
            end
            for i=1:numel(tokens) - 1
                current = tokens{i};
                next = tokens{i+1};
                if strcmp(current.type, 'NUMBER') || ...
                        strcmp(current.type, 'PROPERTY')
                    isValid = strcmp(next.type, 'BINARY-OP') || ...
                        strcmp(next.type, 'RPAREN');
                elseif strcmp(current.type, 'UNARY-OP') || ...
                        strcmp(current.type, 'BINARY-OP') || ...
                        strcmp(current.type, 'LPAREN')
                    isValid = strcmp(next.type, 'NUMBER') || ...
                        strcmp(next.type, 'PROPERTY') || ...
                        strcmp(next.type, 'UNARY-OP') || ...
                        strcmp(next.type, 'LPAREN');
                elseif strcmp(current.type, 'RPAREN')
                    isValid = strcmp(next.type, 'BINARY-OP') || ...
                        strcmp(next.type, 'RPAREN');
                else
                    error("Unknown token type: " + current.type);
                end
                if not(isValid)
                    return;
                end
            end
            % function can't end with an operation or left parenthesis
            last = tokens{end};
            if strcmp(last.type, 'UNARY-OP') || ...
                    strcmp(last.type, 'BINARY-OP') || ...
                    strcmp(last.type, 'LPAREN')
                isValid = false;
            end
        end
        function [isValid] = validate(tokens)
            %% validate - decide whether a token array can form a valid function
            % Two checks are made, in this order:
            %   1. left and right parentheses are balanced
            %   2. the order of the tokens makes sense
            % The ordering check only runs when the first check passes.
            isValid = false;
            if ~ArithmeticParser.areParanthesisBalanced(tokens)
                return;
            end
            if ArithmeticParser.areTokensOrdered(tokens)
                isValid = true;
            end
        end

        function [value] = parse_property(property_string, system, component_manager, H_before)
            %% parse_property - turns a string from AST to system/component properties
            % this function determines whether the property is a system property, that
            % can be obtained with system.Evaluate_Property, or a component related
            % value, that can be obtained with Mixture_Properties.
            % another case is the H_before system property, that functions might use.
            %
            % inputs:
            % property_string   : '#'-separated text. the first segment selects the
            %                     source: "H_before" (case-insensitive), a segment
            %                     starting with "System", or "Component<N>".
            % system            : object used for Evaluate_Property and, for mixture
            %                     properties, P_m / Dry_Switch / Mixture_Array.
            % component_manager : object whose relevant_components{N} is looked up
            %                     for "Component<N>" strings.
            % H_before          : value returned as-is for the "H_before" string.
            %
            % errors:
            %       "Wrong property string format" when the first segment is not
            %       recognised, a required segment is missing, or no component
            %       number can be read from "Component<N>".
            mixturePropertyNames = ["nu", "alpha", "D", "rho", "Pr", "Sc", "gamma", "M", "cp_mol", "cp", "k", "sp", "lh", "Tref", "CB", "lh_m"];
            split_property_string = strsplit(property_string, '#');
            num_parts = numel(split_property_string);

            % all-digit segments become numbers, everything else stays text.
            % an empty segment is kept as text (all() of an empty array is
            % true, which would otherwise turn it into NaN).
            info_array = cell(1, num_parts);
            for i=1:num_parts
                token = split_property_string{i};
                if ~isempty(token) && all(isstrprop(token, 'digit'))
                    info_array{i} = str2double(token);
                else
                    info_array{i} = token;
                end
            end

            head = split_property_string{1};
            if strcmpi(head, "h_before")
                value = H_before;
            elseif startsWith(head, "System")
                if num_parts < 2
                    error("Wrong property string format");
                end
                if strcmp(split_property_string{2}, 'Begin')
                    if num_parts < 3
                        error("Wrong property string format");
                    end
                    % NOTE(review): str2num evaluates its input with eval; if
                    % this segment can come from untrusted text, consider
                    % str2double instead - confirm no caller relies on
                    % non-scalar or expression input here.
                    info_array = {'Begin', str2num(split_property_string{3})};
                else
                    info_array = info_array(2:end);
                end
                value = system.Evaluate_Property(info_array);
            elseif startsWith(head, "Component")
                if num_parts < 2
                    error("Wrong property string format");
                end
                component_num = sscanf(head, 'Component%d');
                if ~isscalar(component_num)
                    error("Wrong property string format");
                end
                component = component_manager.relevant_components{component_num};
                info_array{1} = component;
                property_name = split_property_string{2};
                if ismember(property_name, mixturePropertyNames)  % mixed property
                    [nu, alpha, D, rho, Pr, Sc, gamma, M, cp_mol, cp, k, sp, lh, Tref, CB, lh_m] =  ...
                        Mixture_Properties(system.P_m,component.Temperature(end),system.Dry_Switch,system.Mixture_Array);
                    values = {nu, alpha, D, rho, Pr, Sc, gamma, M, cp_mol, cp, k, sp, lh, Tref, CB, lh_m};
                    props = cell2struct(values, mixturePropertyNames, 2);
                    % property_name is one of mixturePropertyNames, which are
                    % exactly the fields of props
                    value = props.(property_name);
                else  % regular component property
                    value = system.Evaluate_Property(info_array);
                end
            else
                error("Wrong property string format");
            end
        end
        function tokens = tokenize(expr)
            %% tokenize - split an expression into a cell array of tokens
            % inputs:
            % expr : char vector or string scalar. all whitespace is removed
            %        before scanning.
            %
            % outputs:
            % tokens : cell array of structs with the fields 'type' and
            %          'value'. the types produced are:
            %            'NUMBER'    - digits, at most one '.', and 'i';
            %                          value is the result of str2double
            %            'BINARY-OP' - PLUS, MINUS, MUL, DIV, FACTOR (^)
            %            'UNARY-OP'  - UMINUS, or a '|'-prefixed three-letter
            %                          code: REA, IMA, CON, ABS, SIN, COS
            %            'LPAREN' / 'RPAREN'
            %            'PROPERTY'  - anything else, up to the next operator,
            %                          parenthesis or '|'
            %
            % errors:
            %       "Invalid Unary operation" for an unknown or truncated
            %       '|' code, "Invalid number: ..." for a numeral that
            %       str2double cannot read (e.g. "2i3").
            if isstring(expr)
                % indexing below works on characters, not on string elements
                expr = char(expr);
            end
            expr = regexprep(expr, '\s+', '');
            tokens = {};
            i = 1;
            while i <= length(expr)
                c = expr(i);
                if isstrprop(c, 'digit') || strcmp(c, 'i')  % number token
                    numStr = c;
                    i = i + 1;
                    acceptDecimal = true;
                    while i <= length(expr) && ...
                            (isstrprop(expr(i), 'digit') || ...
                            (expr(i) == '.' && acceptDecimal) || ...
                            strcmp(expr(i), 'i'))
                        if expr(i) == '.'
                            acceptDecimal = false;
                        end
                        numStr = [numStr expr(i)];
                        i = i + 1;
                    end
                    numValue = str2double(numStr);
                    % str2double signals failure with NaN; don't let a
                    % malformed numeral slip through as a NaN constant
                    if isnan(numValue)
                        error("Invalid number: %s", numStr);
                    end
                    tokens{end+1} = struct('type', 'NUMBER', 'value', numValue);
                elseif ismember(c, '^+-*/')  % binary operation token
                    % decide unary minus by context
                    if c == '-' && (isempty(tokens) || strcmp(tokens{end}.type,'BINARY-OP') || strcmp(tokens{end}.type,'LPAREN') || strcmp(tokens{end}.type,'UNARY-OP'))
                        tokens{end+1} = struct('type','UNARY-OP','value','UMINUS');
                        i = i + 1;
                        continue;
                    end
                    switch c
                        case '+', t = 'PLUS';
                        case '-', t = 'MINUS';
                        case '*', t = 'MUL';
                        case '/', t = 'DIV';
                        case '^', t = 'FACTOR';
                    end
                    tokens{end+1} = struct('type', 'BINARY-OP', 'value', t);
                    i = i + 1;
                elseif ismember(c, '()')  % parenthesis token
                    switch c
                        case '(', t = 'LPAREN';
                        case ')', t = 'RPAREN';
                    end
                    tokens{end+1} = struct('type', t, 'value', t);
                    i = i + 1;
                elseif ismember(c, '|')  % unary operation token
                    i = i + 1;
                    % the code after '|' is exactly three characters long
                    if i + 2 > length(expr)
                        error("Invalid Unary operation");
                    end
                    switch expr(i:i+2)
                        case 'REA', t = 'Real';
                        case 'IMA', t = 'Imag';
                        case 'CON', t = 'Conj';
                        case 'ABS', t = 'Abs';
                        case 'SIN', t = 'sin';
                        case 'COS', t = 'cos';
                        otherwise
                            error("Invalid Unary operation");
                    end
                    tokens{end+1} = struct('type', 'UNARY-OP', 'value', t);
                    i = i + 3;
                else
                    propertyStr = c;
                    i = i + 1;
                    while i <= length(expr) && ...  % property token
                            ismember(expr(i), '^+-*/()|')==false
                        propertyStr = [propertyStr expr(i)];
                        i = i + 1;
                    end
                    tokens{end+1} = struct('type', 'PROPERTY', 'value', propertyStr);
                end
            end
        end

        function [node, pos] = parseAddition(tokens, pos)
            %% parseAddition - parse a chain of '+' / '-' operations
            % Parses one multiplication-level operand starting at `pos`, then
            % keeps folding further operands into a left-leaning tree for as
            % long as the next token is a PLUS or MINUS binary operation.
            % Returns the resulting node and the position of the first token
            % that was not consumed.
            additiveOps = {'PLUS', 'MINUS'};
            [node, pos] = ArithmeticParser.parseMultiplication(tokens, pos);
            while true
                if pos > length(tokens)
                    break;
                end
                candidate = tokens{pos};
                if ~strcmp(candidate.type, 'BINARY-OP') || ...
                        ~any(strcmp(candidate.value, additiveOps))
                    break;
                end
                [rhs, pos] = ArithmeticParser.parseMultiplication(tokens, pos + 1);
                node = struct('type', 'OPERATION', 'value', candidate.value, ...
                    'left', node, 'right', rhs);
            end
        end

        function [node, pos] = parseMultiplication(tokens, pos)
            %% parseMultiplication - parse a chain of '*' / '/' operations
            % Parses one unary-minus-level operand starting at `pos`, then
            % keeps folding further operands into a left-leaning tree for as
            % long as the next token is a MUL or DIV binary operation.
            % Returns the resulting node and the position of the first token
            % that was not consumed.
            multiplicativeOps = {'MUL', 'DIV'};
            [node, pos] = ArithmeticParser.parseUminus(tokens, pos);
            while true
                if pos > length(tokens)
                    break;
                end
                candidate = tokens{pos};
                if ~strcmp(candidate.type, 'BINARY-OP') || ...
                        ~any(strcmp(candidate.value, multiplicativeOps))
                    break;
                end
                [rhs, pos] = ArithmeticParser.parseUminus(tokens, pos + 1);
                node = struct('type', 'OPERATION', 'value', candidate.value, ...
                    'left', node, 'right', rhs);
            end
        end

        function [node, pos] = parseUminus(tokens, pos)
            %% parseUminus - parse an optional chain of unary minus signs
            % When the token at `pos` is a UMINUS unary operation, it is
            % consumed and the rest is parsed recursively, producing a
            % 'UNARY' node whose 'right' field is the placeholder `false`.
            % Otherwise parsing is handed to parseFactor unchanged.
            startsWithMinus = pos <= length(tokens) && ...
                strcmp(tokens{pos}.type, 'UNARY-OP') && ...
                strcmp(tokens{pos}.value, 'UMINUS');
            if ~startsWithMinus
                [node, pos] = ArithmeticParser.parseFactor(tokens, pos);
                return;
            end
            [operand, pos] = ArithmeticParser.parseUminus(tokens, pos + 1);
            node = struct('type', 'UNARY', 'value', 'UMINUS', ...
                'left', operand, 'right', false);
        end

        function [node, pos] = parseFactor(tokens, pos)
            %% parseFactor - parse a chain of '^' (power) operations
            % Parses one operand with parseParentheses, then folds further
            % operands into a left-leaning tree for as long as the next
            % token is the FACTOR binary operation, so "a^b^c" becomes
            % "(a^b)^c".
            %
            % The exponent may be preceded by unary minus signs ("2^-3"):
            % the tokenizer emits UMINUS after a binary operation and the
            % validator accepts that order, so it has to be parsed here.
            % The minus signs bind to the single operand that follows them,
            % i.e. "2^-3^2" becomes "(2^(-3))^2".
            %
            % Returns the resulting node and the position of the first
            % token that was not consumed.
            [left, pos] = ArithmeticParser.parseParentheses(tokens, pos);
            while pos <= length(tokens) && ...
                    strcmp(tokens{pos}.type, 'BINARY-OP') && ...
                    (strcmp(tokens{pos}.value, 'FACTOR'))
                op = tokens{pos}.value;
                pos = pos + 1;

                % consume any unary minus signs in front of the exponent
                negations = 0;
                while pos <= length(tokens) && ...
                        strcmp(tokens{pos}.type, 'UNARY-OP') && ...
                        strcmp(tokens{pos}.value, 'UMINUS')
                    negations = negations + 1;
                    pos = pos + 1;
                end

                [right, pos] = ArithmeticParser.parseParentheses(tokens, pos);

                % wrap the exponent once per consumed minus sign, using the
                % same node shape parseUminus produces
                for k = 1:negations
                    right = struct('type', 'UNARY', 'value', 'UMINUS', 'left', right, 'right', false);
                end

                left = struct('type', 'OPERATION', 'value', op, 'left', left, 'right', right);
            end
            node = left;
        end

        function [node, pos] = parseParentheses(tokens, pos)
            %% parseParentheses - parse one primary operand
            % Depending on the token at `pos` this produces:
            %   - a 'NUMBER' or 'PROPERTY' leaf node
            %   - a 'UNARY' node for a unary operation other than UMINUS,
            %     applied to the primary operand that follows it
            %   - the node of a whole parenthesised sub-expression
            % Returns the node and the position of the first token that was
            % not consumed.
            %
            % errors:
            %       'Expected number or parenthesis' when there is no token
            %       at `pos` or the token cannot start an operand.
            %       'Expected closing parenthesis' when a sub-expression is
            %       not followed by an RPAREN token.

            % running off the end of the token list is a syntax error, not
            % an indexing error
            if pos > length(tokens)
                error('Expected number or parenthesis');
            end
            token = tokens{pos};
            if strcmp(token.type, 'NUMBER')
                node = struct('type', 'NUMBER', 'value', token.value);
                pos = pos + 1;
            elseif strcmp(token.type, 'PROPERTY')
                node = struct('type', 'PROPERTY', 'value', token.value);
                pos = pos + 1;
            elseif strcmp(token.type, 'UNARY-OP') && ~strcmp(token.value, 'UMINUS')
                pos = pos + 1;
                [left, pos] = ArithmeticParser.parseParentheses(tokens, pos);
                % unary nodes carry `false` in 'right' as a placeholder
                node = struct('type', 'UNARY', 'value', token.value, 'left', left, 'right', false);
            elseif strcmp(token.type, 'LPAREN')
                pos = pos + 1;
                [node, pos] = ArithmeticParser.parseAddition(tokens, pos);
                if pos > length(tokens) || ~strcmp(tokens{pos}.type, 'RPAREN')
                    error('Expected closing parenthesis');
                end
                pos = pos + 1;
            else
                error('Expected number or parenthesis');
            end
        end
    end
end