diff --git a/WebHostLib/tracker.py b/WebHostLib/tracker.py index 75b5fb02..5450ef51 100644 --- a/WebHostLib/tracker.py +++ b/WebHostLib/tracker.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, NamedTuple, from uuid import UUID from email.utils import parsedate_to_datetime -from flask import render_template, make_response, Response, request +from flask import make_response, render_template, request, Request, Response from werkzeug.exceptions import abort from MultiServer import Context, get_saving_second @@ -298,17 +298,25 @@ class TrackerData: return self._multidata.get("spheres", []) -def _process_if_request_valid(incoming_request, room: Optional[Room]) -> Optional[Response]: +def _process_if_request_valid(incoming_request: Request, room: Optional[Room]) -> Optional[Response]: if not room: abort(404) - if_modified = incoming_request.headers.get("If-Modified-Since", None) - if if_modified: - if_modified = parsedate_to_datetime(if_modified) + if_modified_str: Optional[str] = incoming_request.headers.get("If-Modified-Since", None) + if if_modified_str: + if_modified = parsedate_to_datetime(if_modified_str) + if if_modified.tzinfo is None: + abort(400) # standard requires "GMT" timezone + # database may use datetime.utcnow(), which is timezone-naive. convert to timezone-aware. + last_activity = room.last_activity + if last_activity.tzinfo is None: + last_activity = room.last_activity.replace(tzinfo=datetime.timezone.utc) # if_modified has less precision than last_activity, so we bring them to same precision - if if_modified >= room.last_activity.replace(microsecond=0): + if if_modified >= last_activity.replace(microsecond=0): return make_response("", 304) + return None + @app.route("/tracker///") def get_player_tracker(tracker: UUID, tracked_team: int, tracked_player: int, generic: bool = False) -> Response: diff --git a/test/webhost/data/One_Archipelago.archipelago b/test/webhost/data/One_Archipelago.archipelago new file mode 100644 index 00000000..8b7a8ce0 Binary files /dev/null and b/test/webhost/data/One_Archipelago.archipelago differ diff --git a/test/webhost/test_tracker.py b/test/webhost/test_tracker.py new file mode 100644 index 00000000..58145d77 --- /dev/null +++ b/test/webhost/test_tracker.py @@ -0,0 +1,95 @@ +import os +import pickle +from pathlib import Path +from typing import ClassVar +from uuid import UUID, uuid4 + +from flask import url_for + +from . import TestBase + + +class TestTracker(TestBase): + room_id: UUID + tracker_uuid: UUID + log_filename: str + data: ClassVar[bytes] + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + with (Path(__file__).parent / "data" / "One_Archipelago.archipelago").open("rb") as f: + cls.data = f.read() + + def setUp(self) -> None: + from pony.orm import db_session + from MultiServer import Context as MultiServerContext + from Utils import user_path + from WebHostLib.models import GameDataPackage, Room, Seed + + super().setUp() + + multidata = MultiServerContext.decompress(self.data) + + with self.client.session_transaction() as session: + session["_id"] = uuid4() + self.tracker_uuid = uuid4() + with db_session: + # store game datapackage(s) + for game, game_data in multidata["datapackage"].items(): + if not GameDataPackage.get(checksum=game_data["checksum"]): + GameDataPackage(checksum=game_data["checksum"], + data=pickle.dumps(game_data)) + # create an empty seed and a room from it + seed = Seed(multidata=self.data, owner=session["_id"]) + room = Room(seed=seed, owner=session["_id"], tracker=self.tracker_uuid) + self.room_id = room.id + self.log_filename = user_path("logs", f"{self.room_id}.txt") + + def tearDown(self) -> None: + from pony.orm import db_session, select + from WebHostLib.models import Command, Room + + with db_session: + for command in select(command for command in Command if command.room.id == self.room_id): # type: ignore + command.delete() + room: Room = Room.get(id=self.room_id) + room.seed.delete() + room.delete() + + try: + os.unlink(self.log_filename) + except FileNotFoundError: + pass + + def test_valid_if_modified_since(self) -> None: + """ + Verify that we get a 200 response for valid If-Modified-Since + """ + with self.app.app_context(), self.app.test_request_context(): + response = self.client.get( + url_for( + "get_player_tracker", + tracker=self.tracker_uuid, + tracked_team=0, + tracked_player=1, + ), + headers={"If-Modified-Since": "Wed, 21 Oct 2015 07:28:00 GMT"}, + ) + self.assertEqual(response.status_code, 200) + + def test_invalid_if_modified_since(self) -> None: + """ + Verify that we get a 400 response for invalid If-Modified-Since + """ + with self.app.app_context(), self.app.test_request_context(): + response = self.client.get( + url_for( + "get_player_tracker", + tracker=self.tracker_uuid, + tracked_team=1, + tracked_player=0, + ), + headers={"If-Modified-Since": "Wed, 21 Oct 2015 07:28:00"}, # missing timezone + ) + self.assertEqual(response.status_code, 400)