import numpy as np
import pandas as pd
import uuid
import math
from collections import defaultdict
from evaluation import get_actual_demand
from utils import load_problem_data,save_solution
from seeds import known_seeds



class ServerManager:
    """Greedy per-time-step planner producing fleet actions and price records.

    For every time step the planner:

    1. ages the fleet and dismisses servers that reached their life
       expectancy or their cost-optimal dismiss age,
    2. turns the (price-adjusted) demand into per-datacenter placements,
       most profitable (generation, latency) pairs first,
    3. satisfies placements by holding servers already in the target
       datacenter, then by moving idle servers, then by buying new ones,
    4. dismisses every server that ended up unused,
    5. appends the resulting actions to ``self.solution`` and one price
       record per priced (generation, latency) pair to
       ``self.pricing_strategy``.
    """

    # Latency classes, in the order they are evaluated everywhere below.
    _LATENCIES = ('low', 'medium', 'high')

    def __init__(self, problem_data=None):
        """Load the problem tables and precompute per-generation economics.

        Args:
            problem_data: Optional 5-tuple ``(demand, datacenters, servers,
                selling_prices, elasticity)``. When omitted the tuple is
                obtained from ``load_problem_data()``. The datacenter,
                server and selling-price entries are used as pandas
                DataFrames; ``elasticity`` may be a DataFrame with
                ``server_generation`` / ``latency_sensitivity`` /
                ``elasticity`` columns or any object with a dict-like
                ``get((generation, latency), default)``.
        """
        if problem_data is None:
            problem_data = load_problem_data()
        (self.demand, self.datacenters, self.servers,
         self.selling_prices, self.elasticity) = problem_data

        # Plain dictionaries give O(1) lookups by generation / datacenter id.
        self.server_info_dict = self.servers.set_index('server_generation').to_dict('index')
        self.datacenter_info_dict = self.datacenters.set_index('datacenter_id').to_dict('index')

        # Datacenters grouped by latency class, cheapest energy first.
        self.datacenters_by_latency = {
            latency: self.datacenters[
                self.datacenters['latency_sensitivity'] == latency
            ].sort_values('cost_of_energy')
            for latency in self._LATENCIES
        }

        # Fleet state and derived per-generation figures.
        self.existing_servers = []  # Live servers, each a dict
        self.server_optimal_cost = {}  # Lowest average cost per time step, by generation
        self.server_optimal_dismiss_time = {}  # Age at which that lowest average is reached
        self.server_optimal_unit_cost = {}  # Optimal cost divided by slots_size
        self.server_revenue_rate = {}  # {generation: {latency: revenue / cost - 1}}
        self.server_revenue_perslot = {}  # {generation: {latency: revenue - cost per slot}}
        self.solution = []  # Every action emitted so far
        self.server_sizes = {}  # Cached slots_size by generation
        self.pricing_strategy = []  # Every price record emitted so far

        # Actions of the time step currently being processed.
        self.solution_dismiss_step = []
        self.solution_move_step = []
        self.solution_buy_step = []

        # Best price adjustment table (hard-coded), as a fraction of the
        # base selling price.
        self.price_adjustment = {
            ('CPU.S1', 'low'): 0.13,
            ('CPU.S2', 'low'): 0.05,
            ('CPU.S3', 'low'): -0.11,
            ('CPU.S4', 'low'): -0.15,
            ('GPU.S1', 'low'): 0.01,
            ('GPU.S2', 'low'): -0.15,
            ('GPU.S3', 'low'): -0.13,
            ('CPU.S1', 'medium'): -0.07,
            ('CPU.S2', 'medium'): -0.17,
            ('CPU.S3', 'medium'): -0.15,
            ('CPU.S4', 'medium'): -0.17,
            ('GPU.S1', 'medium'): -0.21,
            ('GPU.S2', 'medium'): -0.21,
            ('GPU.S3', 'medium'): -0.19,
            ('CPU.S1', 'high'): 0.13,
            ('CPU.S2', 'high'): -0.09,
            ('CPU.S3', 'high'): -0.05,
            ('CPU.S4', 'high'): -0.15,
            ('GPU.S1', 'high'): 0.01,
            ('GPU.S2', 'high'): 0.05,
            ('GPU.S3', 'high'): 0.01
        }

        # One-off lookups so the per-step code never filters a DataFrame.
        self._base_prices = self._build_base_price_lookup()
        self._elasticity_lookup = self._build_elasticity_lookup(self.elasticity)
        self._release_windows = {}  # Lazily filled by _release_window()

        self._preprocess_servers()

    def _build_base_price_lookup(self):
        """Return ``{(generation, latency): selling_price}`` from the price table.

        When a pair occurs more than once the first row wins.
        """
        prices = {}
        rows = zip(
            self.selling_prices['server_generation'],
            self.selling_prices['latency_sensitivity'],
            self.selling_prices['selling_price'],
        )
        for gen, latency, price in rows:
            prices.setdefault((gen, latency), price)
        return prices

    @staticmethod
    def _build_elasticity_lookup(elasticity):
        """Return an object answering ``get((generation, latency), default)``.

        ``DataFrame.get`` looks up a *column*, so querying a long-format
        elasticity table with a ``(generation, latency)`` tuple always
        yields the default. A table with the expected columns is therefore
        converted to a dictionary; anything else (e.g. an already keyed
        dict) is returned unchanged.
        """
        columns = getattr(elasticity, 'columns', None)
        required = ('server_generation', 'latency_sensitivity', 'elasticity')
        if columns is None or not all(name in columns for name in required):
            return elasticity
        return {
            (gen, latency): value
            for gen, latency, value in zip(
                elasticity['server_generation'],
                elasticity['latency_sensitivity'],
                elasticity['elasticity'],
            )
        }

    @staticmethod
    def _maintenance_cost(average_fee, age, life_expectancy):
        """Return the maintenance fee of one time step at the given age.

        ``age`` is at least 1 wherever this is called, so the logarithm
        argument is always positive.
        """
        life_ratio = age / life_expectancy
        return average_fee * (1 + 1.5 * life_ratio * math.log2(1.5 * life_ratio))

    def _preprocess_servers(self):
        """Compute cost-optimal dismiss ages and revenue figures per generation."""
        for gen, server in self.server_info_dict.items():
            life_expectancy = server['life_expectancy']
            cumulative_maintenance = 0
            min_avg_cost = float('inf')
            best_dismiss_time = 1

            # Average cost per time step if the server is dismissed at `age`;
            # keep the earliest age reaching the minimum.
            for age in range(1, life_expectancy + 1):
                cumulative_maintenance += self._maintenance_cost(
                    server['average_maintenance_fee'], age, life_expectancy
                )
                avg_cost = (cumulative_maintenance + server['purchase_price']) / age
                if avg_cost < min_avg_cost:
                    min_avg_cost = avg_cost
                    best_dismiss_time = age

            unit_cost = min_avg_cost / server['slots_size']
            self.server_optimal_cost[gen] = min_avg_cost
            self.server_optimal_unit_cost[gen] = unit_cost
            self.server_optimal_dismiss_time[gen] = best_dismiss_time
            self.server_sizes[gen] = server['slots_size']

            self.server_revenue_rate[gen] = {}
            self.server_revenue_perslot[gen] = {}
            for latency in self._LATENCIES:
                # A pair without a selling price earns nothing.
                selling_price = self._base_prices.get((gen, latency), 0)

                # Apply the price adjustment.
                adj_percent = self.price_adjustment.get((gen, latency), 0.0)
                adjusted_price = selling_price * (1 + adj_percent)

                unit_revenue = server['capacity'] * adjusted_price / server['slots_size']

                if unit_cost > 0:
                    revenue_rate = unit_revenue / unit_cost - 1
                    revenue_perslot = unit_revenue - unit_cost
                else:
                    revenue_rate = 0
                    revenue_perslot = 0

                self.server_revenue_rate[gen][latency] = revenue_rate
                self.server_revenue_perslot[gen][latency] = revenue_perslot

    def _parse_release_time(self, release_time_str):
        """Parse a ``"[start,end]"`` string into a ``(start, end)`` int tuple."""
        time_range = release_time_str.strip('[]').split(',')
        return int(time_range[0]), int(time_range[1])

    def _release_window(self, server_generation):
        """Return the cached ``(start, end)`` release window of a generation."""
        window = self._release_windows.get(server_generation)
        if window is None:
            info = self._get_server_info(server_generation)
            window = self._parse_release_time(info['release_time'])
            self._release_windows[server_generation] = window
        return window

    def _get_server_info(self, server_generation):
        """Return the attribute dict of a server generation."""
        return self.server_info_dict[server_generation]

    def _get_datacenter_info(self, datacenter_id):
        """Return the attribute dict of a datacenter."""
        return self.datacenter_info_dict[datacenter_id]

    @staticmethod
    def _action_record(time_step, datacenter_id, server_generation, server_id, action):
        """Build one solution row; key order fixes the output column order."""
        return {
            'time_step': time_step,
            'datacenter_id': datacenter_id,
            'server_generation': server_generation,
            'server_id': server_id,
            'action': action
        }

    def process_time_step(self, time_step, actual_demand_dict):
        """Plan one time step and append its actions and price records.

        Args:
            time_step: Index of the step being planned.
            actual_demand_dict: Object supporting ``in`` and ``get`` keyed by
                ``(server_generation, latency)`` and yielding the demanded
                capacity for this step.

        Side effects:
            Updates ``existing_servers``; extends ``solution`` with this
            step's dismiss, move and buy actions (in that order) and
            ``pricing_strategy`` with this step's prices.
        """
        self.solution_dismiss_step = []
        self.solution_move_step = []
        self.solution_buy_step = []

        # 1. Age the fleet and retire servers that hit an age limit.
        self._retire_expired_servers(time_step)

        # 2. Free slots per datacenter once surviving servers are counted.
        remaining_slots = self._remaining_slots_by_datacenter()

        # 3-4. Rank the demanded pairs and turn demand into placements.
        placements = self._plan_placements(actual_demand_dict, remaining_slots)

        # 5-7. Satisfy placements by hold, then move, then buy.
        used_ids = set()
        pending = self._hold_in_place(placements, used_ids)
        pending = self._move_idle_servers(time_step, pending, used_ids, remaining_slots)
        self._buy_servers(time_step, pending, used_ids, remaining_slots)

        # 8. Whatever was not matched to a placement is dismissed.
        self._dismiss_idle_servers(time_step, used_ids)

        # 9. Record this step's actions.
        self.solution.extend(self.solution_dismiss_step)
        self.solution.extend(self.solution_move_step)
        self.solution.extend(self.solution_buy_step)

        # 10. Record this step's prices.
        self._record_prices(time_step)

    def _retire_expired_servers(self, time_step):
        """Increment every server's age and dismiss those past an age limit."""
        survivors = []
        for server in self.existing_servers:
            gen = server['server_generation']
            server['age'] += 1
            expired = (
                server['age'] >= self._get_server_info(gen)['life_expectancy']
                or server['age'] >= self.server_optimal_dismiss_time[gen]
            )
            if expired:
                self.solution_dismiss_step.append(self._action_record(
                    time_step, server['datacenter_id'], gen,
                    server['server_id'], 'dismiss'
                ))
            else:
                survivors.append(server)
        self.existing_servers = survivors

    def _remaining_slots_by_datacenter(self):
        """Return ``{datacenter_id: capacity minus slots of existing servers}``."""
        remaining = {
            dc_id: info['slots_capacity']
            for dc_id, info in self.datacenter_info_dict.items()
        }
        for server in self.existing_servers:
            remaining[server['datacenter_id']] -= self.server_sizes[server['server_generation']]
        return remaining

    def _plan_placements(self, actual_demand_dict, remaining_slots):
        """Convert demand into a list of ``(generation, datacenter_id)`` placements.

        Pairs are served in descending revenue-per-slot order and each one
        fills the cheapest-energy datacenters of its latency class first.
        ``remaining_slots`` is reduced in place for every placement.
        """
        # Only pairs present in the actual (not forecast) demand are ranked.
        ranked = sorted(
            (
                (gen, latency, self.server_revenue_perslot[gen][latency])
                for gen in self.server_info_dict
                for latency in self._LATENCIES
                if (gen, latency) in actual_demand_dict
            ),
            key=lambda item: item[2],
            reverse=True,
        )

        dc_ids_by_latency = {
            latency: frame['datacenter_id'].tolist()
            for latency, frame in self.datacenters_by_latency.items()
        }

        placements = []
        for gen, latency, _ in ranked:
            key = (gen, latency)
            # Demand reacts to the price change: D * (1 + elasticity * dP/P).
            demand_factor = 1 + self._elasticity_lookup.get(key, 0) * self.price_adjustment.get(key, 0)
            capacity_demand = round(actual_demand_dict.get(key, 0) * demand_factor)
            if capacity_demand <= 0:
                continue

            size = self.server_sizes[gen]
            # Rounded down: a server is only planned when it can be fully used.
            servers_left = math.floor(capacity_demand / self._get_server_info(gen)['capacity'])

            for dc_id in dc_ids_by_latency[latency]:
                if servers_left <= 0:
                    break
                count = min(remaining_slots[dc_id] // size, servers_left)
                if count > 0:
                    placements.extend([(gen, dc_id)] * count)
                    remaining_slots[dc_id] -= count * size
                    servers_left -= count
        return placements

    def _hold_in_place(self, placements, used_ids):
        """Match placements to servers already sitting in the target datacenter.

        Matched server ids are added to ``used_ids``; the placements that
        found no server are returned in their original order.
        """
        pools = defaultdict(list)
        for server in self.existing_servers:
            pools[(server['server_generation'], server['datacenter_id'])].append(server)
        # Iterators hand out each pooled server exactly once, oldest entry first.
        idle = {key: iter(servers) for key, servers in pools.items()}

        pending = []
        for placement in placements:
            pool = idle.get(placement)
            server = next(pool, None) if pool is not None else None
            if server is None:
                pending.append(placement)
            else:
                used_ids.add(server['server_id'])
        return pending

    def _move_idle_servers(self, time_step, pending, used_ids, remaining_slots):
        """Serve pending placements by moving idle servers of the same generation.

        Returns the placements that are still unserved afterwards.
        """
        by_generation = defaultdict(list)
        for server in self.existing_servers:
            by_generation[server['server_generation']].append(server)
        # used_ids only grows, so a per-generation cursor that skips used
        # servers always lands on the first idle one without rescanning.
        cursor = defaultdict(int)

        still_pending = []
        for gen, target_dc_id in pending:
            candidates = by_generation.get(gen, ())
            position = cursor[gen]
            while position < len(candidates) and candidates[position]['server_id'] in used_ids:
                position += 1
            cursor[gen] = position

            moved = False
            if position < len(candidates):
                server = candidates[position]
                start_time, end_time = self._release_window(gen)
                size = self.server_sizes[gen]
                # Moves are only issued while the generation is on sale.
                # NOTE(review): the slots of this placement were already
                # subtracted in _plan_placements, so this check demands a
                # second free block and the update below counts the server
                # twice. Kept to preserve current output - confirm intent.
                if start_time <= time_step <= end_time and remaining_slots[target_dc_id] >= size:
                    original_dc_id = server['datacenter_id']
                    self.solution_move_step.append(self._action_record(
                        time_step, target_dc_id, gen, server['server_id'], 'move'
                    ))
                    server['datacenter_id'] = target_dc_id
                    used_ids.add(server['server_id'])
                    remaining_slots[original_dc_id] += size
                    remaining_slots[target_dc_id] -= size
                    moved = True

            if not moved:
                still_pending.append((gen, target_dc_id))
        return still_pending

    def _buy_servers(self, time_step, pending, used_ids, remaining_slots):
        """Buy a new server for every pending placement that is allowed one."""
        for gen, target_dc_id in pending:
            start_time, end_time = self._release_window(gen)
            size = self.server_sizes[gen]
            # NOTE(review): same double reservation as in _move_idle_servers;
            # the placement's slots were already subtracted when planned.
            if remaining_slots[target_dc_id] < size or not (start_time <= time_step <= end_time):
                continue

            server_id = str(uuid.uuid4())
            self.solution_buy_step.append(self._action_record(
                time_step, target_dc_id, gen, server_id, 'buy'
            ))
            self.existing_servers.append({
                'server_id': server_id,
                'server_generation': gen,
                'datacenter_id': target_dc_id,
                'age': 0,
            })
            remaining_slots[target_dc_id] -= size
            # Marking the new server as used keeps it out of the idle sweep.
            used_ids.add(server_id)

    def _dismiss_idle_servers(self, time_step, used_ids):
        """Dismiss every server whose id is not in ``used_ids``."""
        kept = []
        for server in self.existing_servers:
            if server['server_id'] in used_ids:
                kept.append(server)
            else:
                self.solution_dismiss_step.append(self._action_record(
                    time_step, server['datacenter_id'],
                    server['server_generation'], server['server_id'], 'dismiss'
                ))
        self.existing_servers = kept

    def _record_prices(self, time_step):
        """Append one rounded adjusted price per pair that has a base price."""
        for (gen, latency), adj_percent in self.price_adjustment.items():
            base_price = self._base_prices.get((gen, latency))
            if base_price is None:
                continue
            self.pricing_strategy.append({
                'time_step': time_step,
                'latency_sensitivity': latency,
                'server_generation': gen,
                'price': round(base_price * (1 + adj_percent))
            })


def get_my_solution(demand_data):
    """Run a fresh planner over every time step of ``demand_data``.

    Args:
        demand_data: Sized object with a dict-like ``get``; steps are
            numbered ``1..len(demand_data)`` and a step without an entry
            is planned with an empty demand.

    Returns:
        Tuple ``(solution, pricing_strategy)`` taken from the planner after
        the last step.
    """
    planner = ServerManager()
    last_step = len(demand_data)

    for step in range(1, last_step + 1):
        demand_for_step = demand_data.get(step, {})
        print(f"Processing time step {step} ")
        planner.process_time_step(step, demand_for_step)

    return planner.solution, planner.pricing_strategy


# Main program


seeds = known_seeds()

# The demand table is read once and reused for every seed.
demand = pd.read_csv('./data/demand.csv')
for seed in seeds:
    # Seed NumPy before deriving the demand; presumably get_actual_demand
    # draws from NumPy's global RNG - confirm against evaluation module.
    np.random.seed(seed)
    print(f"Processing seed: {seed}")
    actual_demand = get_actual_demand(demand)
    fleet_solution, pricing_solution = get_my_solution(actual_demand)

    # One output file per seed.
    fleet_frame = pd.DataFrame(fleet_solution)
    pricing_frame = pd.DataFrame(pricing_solution)
    save_solution(fleet_frame, pricing_frame, f'./output/{seed}.json')