diff --git a/stories/writingtogether/templates/writingtogether/index.html b/stories/writingtogether/templates/writingtogether/index.html
index 5659a4e..2ea07eb 100644
--- a/stories/writingtogether/templates/writingtogether/index.html
+++ b/stories/writingtogether/templates/writingtogether/index.html
@@ -12,8 +12,8 @@
{% if open_story_round_list %}
diff --git a/stories/writingtogether/tests/test_views.py b/stories/writingtogether/tests/test_views.py
new file mode 100644
index 0000000..ac80dbd
--- /dev/null
+++ b/stories/writingtogether/tests/test_views.py
@@ -0,0 +1,92 @@
+from django.contrib.auth import get_user_model
+from django.test import TestCase
+from django.urls import reverse
+
+from writingtogether.models import Story, StoryPart
+
+User = get_user_model()
+
+
+class TestViews(TestCase):
+ def setUp(self) -> None:
+ self.user1 = User.objects.create(username='player1')
+ self.user2 = User.objects.create(username='player2')
+
+ def test_create_story_round_two_rounds(self):
+
+ response = self.client.post(
+ reverse('writing:create_story_round'),
+ {
+ 'name': 'test round',
+ 'participants': [self.user1.pk, self.user2.pk],
+ 'number_of_rounds': 2
+ }
+ )
+
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(Story.objects.all().count(), 2)
+ self.assertEqual(StoryPart.objects.all().count(), 4)
+
+ story_user1 = Story.objects.get(started_by=self.user1.pk)
+ story_user2 = Story.objects.get(started_by=self.user2.pk)
+
+ story_user1_parts = StoryPart.objects.filter(part_of=story_user1)
+ story_user2_parts = StoryPart.objects.filter(part_of=story_user2)
+
+ # each story has 2 story parts
+ self.assertEqual(story_user1_parts.count(), 2)
+ self.assertEqual(story_user2_parts.count(), 2)
+
+ def test_create_story_round_4_rounds(self):
+
+ response = self.client.post(
+ reverse('writing:create_story_round'),
+ {
+ 'name': 'test round',
+ 'participants': [self.user1.pk, self.user2.pk],
+ 'number_of_rounds': 4
+ }
+ )
+
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(Story.objects.all().count(), 2)
+ self.assertEqual(StoryPart.objects.all().count(), 8)
+
+ story_user1 = Story.objects.get(started_by=self.user1.pk)
+ story_user2 = Story.objects.get(started_by=self.user2.pk)
+
+ story_user1_parts = StoryPart.objects.filter(part_of=story_user1)
+ story_user2_parts = StoryPart.objects.filter(part_of=story_user2)
+
+ # each story has 2 story parts
+ self.assertEqual(story_user1_parts.count(), 4)
+ self.assertEqual(story_user2_parts.count(), 4)
+
+ def test_create_story_round_three_users_two_rounds(self):
+ user3 = User.objects.create(username='player3')
+
+ response = self.client.post(
+ reverse('writing:create_story_round'),
+ {
+ 'name': 'test round',
+ 'participants': [self.user1.pk, self.user2.pk, user3.pk],
+ 'number_of_rounds': 2
+ }
+ )
+
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(Story.objects.all().count(), 3)
+ self.assertEqual(StoryPart.objects.all().count(), 6)
+
+ story_user1 = Story.objects.get(started_by=self.user1.pk)
+ story_user2 = Story.objects.get(started_by=self.user2.pk)
+ story_user3 = Story.objects.get(started_by=user3.pk)
+
+ story_user1_parts = StoryPart.objects.filter(part_of=story_user1)
+ story_user2_parts = StoryPart.objects.filter(part_of=story_user2)
+ story_user3_parts = StoryPart.objects.filter(part_of=story_user3)
+
+ # each story has 2 story parts
+ self.assertEqual(story_user1_parts.count(), 2)
+ self.assertEqual(story_user2_parts.count(), 2)
+ self.assertEqual(story_user3_parts.count(), 2)
\ No newline at end of file
diff --git a/stories/writingtogether/views.py b/stories/writingtogether/views.py
index 48a8d55..81004c6 100644
--- a/stories/writingtogether/views.py
+++ b/stories/writingtogether/views.py
@@ -31,24 +31,27 @@ class StoryRoundCreate(CreateView):
def form_valid(self, form):
self.object = form.save()
- sorted_participants = sorted(form.cleaned_data['participants'])
+ sorted_participants = sorted([user.pk for user in form.cleaned_data['participants']])
number_of_participants = len(sorted_participants)
- for user in sorted_participants:
+ for user_id in sorted_participants:
story = Story.objects.create(
part_of_round=self.object,
- started_by=user
+ started_by_id=user_id
)
previous_part=None
for i in range(form.cleaned_data['number_of_rounds']):
current_part = StoryPart.objects.create(
- user=sorted_participants[i % number_of_participants],
+ user_id=sorted_participants[i % number_of_participants],
previous_part=previous_part,
- part_of=story
+ part_of=story,
+ turn_number=0
)
previous_part = current_part
+ return HttpResponseRedirect(self.get_success_url())
+
class StoryUpdate(UpdateView):
model = Story